By default, tf.nn.dynamic_rnn outputs only the hidden states (known as m) for each time step, which can be obtained as follows:
import tensorflow as tf

# inputs: [batch_size, max_time, input_depth], sequence_lengths: [batch_size]
cell = tf.contrib.rnn.LSTMCell(100)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                   inputs=inputs,
                                   sequence_length=sequence_lengths,
                                   dtype=tf.float32)
Is there a way to get the intermediate (not final) cell states (c) in addition?
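I know I can keep the second return value (an LSTMStateTuple) instead of discarding it, but that only carries the state of the last valid time step, so it gives one c per sequence rather than one per step, roughly like this:

rnn_outputs, final_state = tf.nn.dynamic_rnn(cell,
                                              inputs=inputs,
                                              sequence_length=sequence_lengths,
                                              dtype=tf.float32)
final_c = final_state.c  # [batch_size, 100] -- cell state at the last valid step only
final_h = final_state.h  # same values that rnn_outputs ends with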
A TensorFlow contributor mentions that it can be done with a cell wrapper:
class Wrapper(tf.nn.rnn_cell.RNNCell):
    def __init__(self, inner_cell):
        super(Wrapper, self).__init__()
        self._inner_cell = inner_cell

    @property
    def state_size(self):
        return self._inner_cell.state_size

    @property
    def output_size(self):
        # Emit the state alongside the regular output at every step
        return (self._inner_cell.state_size, self._inner_cell.output_size)

    def call(self, input, state):
        output, next_state = self._inner_cell(input, state)
        emit_output = (next_state, output)
        return emit_output, next_state
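For context, a minimal way to plug the wrapper into dynamic_rnn would look something like this (just a sketch; variable names are illustrative):

wrapped_cell = Wrapper(tf.contrib.rnn.LSTMCell(100))
# Expected (but not actually obtained): states would hold c and m for every time step
(states, rnn_outputs), _ = tf.nn.dynamic_rnn(wrapped_cell,
                                             inputs=inputs,
                                             sequence_length=sequence_lengths,
                                             dtype=tf.float32)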
However, it doesn't seem to work. Any ideas?