
By default, the dynamic_rnn function outputs only the hidden states (known as m) for each time step, which can be obtained as follows:

import tensorflow as tf

# inputs and sequence_lengths are defined elsewhere
cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)

Is there a way to get the intermediate (not final) cell states (c) in addition?

A TensorFlow contributor mentions that this can be done with a cell wrapper:

class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
    super(Wrapper, self).__init__()
    self._inner_cell = inner_cell

  @property
  def state_size(self):
    return self._inner_cell.state_size

  @property
  def output_size(self):
    return (self._inner_cell.state_size, self._inner_cell.output_size)

  def call(self, input, state):
    output, next_state = self._inner_cell(input, state)
    emit_output = (next_state, output)
    return emit_output, next_state

However, it doesn't seem to work. Any ideas?


2 Answers


The proposed solution works for me, but the Layer.call method spec is more general, so the following Wrapper should be more robust to API changes. Try this:

import numpy as np
import tensorflow as tf


class Wrapper(tf.nn.rnn_cell.RNNCell):
  def __init__(self, inner_cell):
    super(Wrapper, self).__init__()
    self._inner_cell = inner_cell

  @property
  def state_size(self):
    return self._inner_cell.state_size

  @property
  def output_size(self):
    # Emit the inner cell's state in addition to its regular output.
    return (self._inner_cell.state_size, self._inner_cell.output_size)

  def call(self, input, *args, **kwargs):
    output, next_state = self._inner_cell(input, *args, **kwargs)
    emit_output = (next_state, output)
    return emit_output, next_state

Here's the test:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
print(outputs, states)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val = outputs[0].eval(feed_dict={X: X_batch})
  print(outputs_val)

Returned outputs is a tuple of (?, 2, 10) and (?, 2, 5) tensors, i.e. all of the intermediate LSTM states and outputs. Note that I'm using the "graduated" version of LSTMCell from the tf.nn.rnn_cell package, not tf.contrib.rnn. Also note state_is_tuple=False to avoid dealing with LSTMStateTuple.
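
If you only want the intermediate cell states c, here's a minimal sketch of how you could slice them out of the first tensor; it assumes the concatenated [c, h] state layout that LSTMCell produces with state_is_tuple=False:

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  states_val, outputs_val = sess.run(outputs, feed_dict={X: X_batch})
  # With state_is_tuple=False the state is [c, h] concatenated along the last axis,
  # so the first n_neurons features of each step are the intermediate cell states.
  c_states = states_val[:, :, :n_neurons]  # shape (4, 2, 5)
  h_states = states_val[:, :, n_neurons:]  # shape (4, 2, 5)
  print(c_states.shape, h_states.shape)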

Maxim

Based on Maxim's idea, I ended up with the following solution:

import tensorflow as tf

LSTMCell = tf.nn.rnn_cell.LSTMCell  # the "graduated" cell mentioned above


class StatefulLSTMCell(LSTMCell):
    def __init__(self, *args, **kwargs):
        super(StatefulLSTMCell, self).__init__(*args, **kwargs)

    @property
    def output_size(self):
        # Emit the full state alongside the regular output at every time step.
        return (self.state_size, super(StatefulLSTMCell, self).output_size)

    def call(self, input, state):
        output, next_state = super(StatefulLSTMCell, self).call(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state
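
For completeness, here's a minimal usage sketch; it assumes the same n_neurons, X placeholder and X_batch as in Maxim's test above, and state_is_tuple=False so the emitted state is a single tensor:

cell = StatefulLSTMCell(num_units=n_neurons, state_is_tuple=False)
(all_states, rnn_outputs), final_state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    states_val = sess.run(all_states, feed_dict={X: X_batch})
    print(states_val.shape)  # (4, 2, 2 * n_neurons): per-step [c, h] for every instance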
user2740947