12

How do you get all the hidden states from tf.nn.rnn() or tf.nn.dynamic_rnn() in TensorFlow? The API only gives me the final state.

The first alternative would be to write a loop when building a model that operates directly on RNNCell. However, the number of timesteps is not fixed for me, and depends on the incoming batch.

Some options are to either use a GRU or to write my own RNNCell that concatenates the state to the output. The former choice isn't general enough and the latter sounds too hacky.

Another option is to do something like the answers in this question, getting all the variables from an RNN. However, I'm not sure how to separate the hidden states from other variables in a standard fashion here.

Is there a nice way to get all the hidden states from an RNN while still using the library-provided RNN APIs?

Ankit Vani
  • I've created a PR [here](https://github.com/tensorflow/tensorflow/pull/9995) and it might help you deal with simple cases – Carefree0910 May 27 '17 at 03:40

2 Answers

2

tf.nn.dynamic_rnn (and also tf.nn.static_rnn) has two return values: "outputs" and "state" (https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn).

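As you said, "state" is the final state of the RNN, but "outputs" holds the cell output at every timestep. For tf.nn.dynamic_rnn its shape is [batch_size, max_time, cell.output_size] by default (time_major=False); tf.nn.static_rnn returns a list with one tensor per timestep instead.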

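You can use "outputs" as the hidden states of the RNN because, in most library-provided RNNCells (for example BasicRNNCell and GRUCell), the output and the state are the same. LSTMCell is the exception: its state is a (c, h) pair and its output is only the h part, so the cell state c does not appear in "outputs".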
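To make the outputs-versus-state distinction concrete, here is a minimal pure-Python sketch. It is not TensorFlow code: `vanilla_cell`, `lstm_like_cell`, and `unroll` are hypothetical toy stand-ins with scalar states and arbitrary weights, and they only mimic the `(output, new_state)` contract of an RNNCell for a single sequence.

```python
import math

def vanilla_cell(x, h):
    """Toy cell whose output IS its new state (like a basic RNN or GRU cell)."""
    h_new = math.tanh(0.5 * x + 0.8 * h)  # weights are arbitrary
    return h_new, h_new

def lstm_like_cell(x, state):
    """Toy cell whose state is a (c, h) pair but whose output is h only."""
    c, h = state
    c_new = 0.9 * c + math.tanh(0.5 * x + 0.8 * h)
    h_new = math.tanh(c_new)
    return h_new, (c_new, h_new)

def unroll(cell, inputs, initial_state):
    """Return (outputs, final_state), mirroring the two return values above."""
    outputs, state = [], initial_state
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)  # one output per timestep
    return outputs, state

inputs = [1.0, -0.5, 0.25]
outputs, final_state = unroll(vanilla_cell, inputs, 0.0)
lstm_outputs, lstm_final = unroll(lstm_like_cell, inputs, (0.0, 0.0))
```

For the vanilla cell, `outputs` already is the full state history and `outputs[-1]` equals `final_state`. For the LSTM-like cell, `lstm_outputs` only tracks h, so the c values of earlier timesteps are not recoverable from it.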

Junyeop Lee
  • Setting aside that this is specific to GRU, this doesn't help you if you have multiple layers, for instance if you wrap [GRUCell](https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/GRUCell) in a [MultiRNNCell](https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/MultiRNNCell). Your output will only contain states from the final layer. – Ankit Vani Nov 16 '17 at 21:43
0

I've created a PR here that might help you deal with simple cases.

Let me briefly explain my implementation, so you can write your own version if you need to. The main part is the modification of the _time_step function:

def _time_step(time, output_ta_t, state, *args):

The parameters remain the same, except that an extra *args is passed in. Why *args? Because I want to keep TensorFlow's customary behavior: if you only want the final state, you simply leave the extra loop variable out and args stays empty:

if states_ta is not None:
    # If you want to return all states, set `args` to be `states_ta`
    loop_vars = (time, output_ta, state, states_ta)
else:
    # If you want the final state only, ignore `args`
    loop_vars = (time, output_ta, state)

How is args used? Inside _time_step, the new state is written to its TensorArray at each timestep:

if args:
    args = tuple(
        ta.write(time, out) for ta, out in zip(args[0], [new_state])
    )

In fact, this is just a modification of the following (original) code:

output_ta_t = tuple(
    ta.write(time, out) for ta, out in zip(output_ta_t, output)
)

Now args should contain all the states you want.

With all of the above in place, you can pick up the states (or just the final state) with the following code:

_, output_final_ta, *state_info = control_flow_ops.while_loop( ...

and

if states_ta is not None:
    final_state, states_final_ta = state_info
else:
    final_state, states_final_ta = state_info[0], None
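The fragments above can be put together in a small pure-Python sketch of the same pattern. This is an illustration under stated assumptions, not the code from the PR: `while_loop`, `run_rnn`, and `counter_cell` are hypothetical names, and plain lists stand in for TensorArrays.

```python
def while_loop(cond, body, loop_vars):
    """Toy stand-in for a graph while-loop: apply body until cond fails."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

def run_rnn(cell, inputs, initial_state, return_all_states=False):
    def time_step(time, outputs, state, *args):
        output, new_state = cell(inputs[time], state)
        outputs = outputs + [output]
        if args:
            # Extra loop variable present: record this step's state too.
            args = (args[0] + [new_state],)
        return (time + 1, outputs, new_state) + args

    loop_vars = (0, [], initial_state)
    if return_all_states:
        loop_vars += ([],)  # the optional accumulator for all states

    _, outputs, *state_info = while_loop(
        lambda time, *_: time < len(inputs), time_step, loop_vars
    )
    if return_all_states:
        final_state, all_states = state_info
    else:
        final_state, all_states = state_info[0], None
    return outputs, final_state, all_states

def counter_cell(x, state):
    """Hypothetical toy cell: state is a running sum, output is twice the state."""
    new_state = state + x
    return 2 * new_state, new_state

outputs, final_state, all_states = run_rnn(
    counter_cell, [1, 2, 3], 0, return_all_states=True
)
default_outputs, default_final, default_states = run_rnn(counter_cell, [1, 2, 3], 0)
```

Because the accumulator is an optional trailing loop variable, callers that do not ask for all states get exactly the original three loop variables and the original return values.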

Although I haven't tested it in complicated cases, it should work under 'simple' conditions (here are my test cases).

  • Thanks for taking the time to compose an answer. In answer to your first sentence, it's better not to have duplicate information on Stack Overflow. Once you have 75 reputation, you will be able to flag one question as a duplicate of the other (though I might be wrong and maybe you can do that now). If the questions are not the same, it would be better to tailor each answer to suit the question's needs. – zondo May 31 '17 at 01:15
  • Thanks for your comment! I've already found some differences between those two questions so now I followed your advice and tailored each answer :) – Carefree0910 May 31 '17 at 01:28
  • The way I actually solved this was by creating a wrapper cell (like MultiRNNCell) that output the states concatenated with the outputs. One need only do a split after that to separate outputs from hidden states. – Ankit Vani May 31 '17 at 05:12
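The wrapper-cell approach described in the last comment can be sketched in pure Python as follows. This is a toy illustration, not the commenter's actual code and not TensorFlow: `gru_like_cell`, `stack_cells`, `expose_states`, and `unroll` are hypothetical helpers, and Python lists stand in for tensors.

```python
import math

def gru_like_cell(x, h):
    """Hypothetical toy cell: vector state, output equals the new state."""
    h_new = [math.tanh(0.5 * x_i + 0.8 * h_i) for x_i, h_i in zip(x, h)]
    return h_new, h_new

def stack_cells(cells):
    """Toy analogue of MultiRNNCell: the state is a tuple, one entry per layer."""
    def stacked(x, states):
        new_states = []
        for cell, state in zip(cells, states):
            x, new_state = cell(x, state)  # each layer feeds the next
            new_states.append(new_state)
        return x, tuple(new_states)
    return stacked

def expose_states(cell):
    """Wrap a cell so its output is the real output followed by every layer's state."""
    def wrapped(x, states):
        output, new_states = cell(x, states)
        flat = [v for layer_state in new_states for v in layer_state]
        return output + flat, new_states
    return wrapped

def unroll(cell, inputs, initial_state):
    outputs, state = [], initial_state
    for x in inputs:
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state

units, layers = 2, 2
cell = expose_states(stack_cells([gru_like_cell] * layers))
inputs = [[1.0, -1.0], [0.5, 0.0], [0.0, 0.25]]
zero_state = tuple([0.0] * units for _ in range(layers))
combined, final_state = unroll(cell, inputs, zero_state)

# Split each timestep back into the real output and the per-layer states.
step_outputs = [row[:units] for row in combined]
all_states = [
    [row[units + i * units: units + (i + 1) * units] for i in range(layers)]
    for row in combined
]
```

The unrolling code never needs to change: the wrapper only widens the per-step output, and a split afterwards recovers the states of every layer at every timestep, which also covers the multi-layer case raised in the earlier comment.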