How do you get all the hidden states from tf.nn.rnn() or tf.nn.dynamic_rnn() in TensorFlow? The API only gives me the final state.
The first alternative would be to write a loop over the timesteps when building the model, calling the RNNCell directly at each step. However, the number of timesteps is not fixed in my case; it depends on the incoming batch.
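Other options are to use a GRU, whose output at each step is the same as its hidden state, or to write my own RNNCell that concatenates the state to the output. The former isn't general enough (it doesn't help with an LSTM, for example, where the cell state is not part of the output), and the latter sounds too hacky.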
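To make the second option concrete, here is a rough, framework-free sketch of what I mean. It is plain Python rather than TensorFlow: ToyCell, StateExposingCell, and run_rnn are hypothetical names made up for illustration, and run_rnn only imitates the "per-step outputs plus final state" contract of the library functions, assuming a cell is called as cell(x, state) and returns (output, new_state).

```python
import math


class ToyCell:
    """Hypothetical stand-in for an RNNCell: cell(x, state) -> (output, new_state)."""

    def __call__(self, x, state):
        new_state = 0.5 * state + x    # internal state, not visible in the output
        output = math.tanh(new_state)  # emitted output differs from the state
        return output, new_state


class StateExposingCell:
    """Wraps a cell so that each step's output also carries the new state."""

    def __init__(self, cell):
        self._cell = cell

    def __call__(self, x, state):
        output, new_state = self._cell(x, state)
        # Pair the state with the output; the recurrent state itself is unchanged.
        return (output, new_state), new_state


def run_rnn(cell, inputs, initial_state):
    """Imitates the library loop: returns all outputs but only the final state."""
    state = initial_state
    outputs = []
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state


inputs = [1.0, 2.0, 3.0]
outputs, final_state = run_rnn(StateExposingCell(ToyCell()), inputs, 0.0)
all_states = [s for _, s in outputs]  # one state per timestep
```

With the wrapper, the per-step states come back through the outputs even though the loop itself still only returns the final state. In TensorFlow the equivalent would presumably be an RNNCell subclass that reports a larger output size, which is the part that feels hacky to me.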
Another option is to do something like the answers in this question, which get all the variables from an RNN. However, I'm not sure how to separate the hidden states from the other variables in a standard fashion.
Is there a nice way to get all the hidden states from an RNN while still using the library-provided RNN APIs?