Calling dynamic_rnn on a sequence of inputs currently returns a sequence of outputs and the last cell state.
For the task at hand (truncated back-propagation that can start/end at any index in the sequence) I need access not only to the last cell state but to the entire sequence of intermediate states as well. A quick online search turned up the following thread:
https://github.com/tensorflow/tensorflow/issues/5731
The thread suggests that the best approach is to extend the original LSTM cell so that it returns the whole state as part of its output; that way the first return value of dynamic_rnn would contain the sequence of states as well.
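For reference, this is roughly the vanilla setup I am starting from (the placeholder and sizes are made up for illustration):

import tensorflow as tf

# Made-up shapes, just to be concrete: [batch, time, input_dim].
inputs = tf.placeholder(tf.float32, [None, 50, 8])
cell = tf.contrib.rnn.LSTMCell(10)

# outputs: h at every time step, shape [batch, time, 10].
# last_state: a single LSTMStateTuple (c, h) for the final step only.
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)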
edit 3:
After four hours of poking around and searching for solutions online, it seems I also need to override the output_size property. This is the updated code:
class SSLSTMCell(tf.contrib.rnn.LSTMCell):
    @property
    def output_size(self):
        # The per-step "output" is now [c, h] concatenated along axis 1,
        # so it is twice as wide as h alone.
        return self.state_size.c + self.state_size.h

    def __call__(self, inputs, state):
        cell_out, cell_state = super(SSLSTMCell, self).__call__(inputs, state)
        # Emit the full state (c and h) as the output at every time step.
        together = tf.concat([cell_state.c, cell_state.h], axis=1)
        return together, cell_state
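A minimal sketch of how I wire this up (the placeholder and sizes are again made up for illustration):

num_units = 10
cell = SSLSTMCell(num_units, state_is_tuple=True)
inputs = tf.placeholder(tf.float32, [None, 50, 8])  # [batch, time, input_dim]

# "outputs" is now [batch, time, 2 * num_units]: c and h at every step.
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

# Split the concatenation back into the per-step c and h sequences.
all_c, all_h = tf.split(outputs, 2, axis=2)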
I take it that to achieve a similar result for a stacked LSTM (MultiRNNCell) one would do the same trick, overriding both the __call__ function and the output_size property.
I do have a real problem with this approach, however: the "output" is now a concatenation of everything together and the tuple structure is lost, and I cannot see a way to make it work while preserving that structure. I understand I can do some splitting/reshaping to get the tensors back into tuple form, which can then be fed again as initial states to the stacked LSTM (a sketch of that is below), but if there is any way to preserve the tuple structure directly it would be great.
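This is the kind of reshaping I mean, as a rough sketch for a single-layer cell; the index t is made up and continues from the sketch above:

t = 25  # made-up index: the step whose state I want to restart from
c_t, h_t = tf.split(outputs[:, t, :], 2, axis=1)
state_t = tf.contrib.rnn.LSTMStateTuple(c=c_t, h=h_t)
# state_t can now be passed as initial_state= to another dynamic_rnn call
# (or, one such tuple per layer, combined into a tuple for a MultiRNNCell).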
Older edits that may provide context but may no longer be relevant to the discussion
This is what I have done so far:
class SSLSTMCell(tf.contrib.rnn.LSTMCell):
    def call(self, inputs, state):
        cell_out, cell_state = super(SSLSTMCell, self).call(inputs, state)
        return cell_state, cell_state
As you can see, instead of returning the (output, state) pair at each time step I simply return (state, state), which should give me access to all the intermediate states.
However, it doesn't seem like the "call" function of my custom subclass SSLSTMCell is being invoked at all during the dynamic_rnn call; indeed, when I put an "assert 0" in the body of the call function, my program doesn't crash.
So what could be wrong? I looked at the implementation of dynamic_rnn and the "call" function is definitely being used, but somehow it doesn't use the one defined by my custom subclass.
Thanks in advance.
edit 1:
It seems I should put underscores around the "call" function (i.e. define __call__ instead). I am not sure why, but this is my updated code:
class SSLSTMCell(tf.contrib.rnn.LSTMCell):
    def __call__(self, inputs, state):
        cell_out, cell_state = super(SSLSTMCell, self).__call__(inputs, state)
        return cell_state, cell_state
This works better, as the "__call__" function is being invoked this time (putting an assert 0 inside it crashes Python). However, I now get a different error:
File "/home/evan/tensorflow/local/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py",
line 655, in call cur_inp, new_state = cell(cur_inp, cur_state) File "try1.py", line 10, in call cell_out, cell_state = super(SSLSTMCell, self).call(inputs, state) File "/home/evan/tensorflow/local/lib/python2.7/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py", line 327, in call input_size = inputs.get_shape().with_rank(2)[1] AttributeError: 'LSTMStateTuple' object has no attribute 'get_shape'
edit 2:
It seems the output must be a regular tensor for the next step of the computation, and an LSTMStateTuple is a tuple of tensors. So I tried concatenating them:
class SSLSTMCell(tf.contrib.rnn.LSTMCell):
    def __call__(self, inputs, state):
        cell_out, cell_state = super(SSLSTMCell, self).__call__(inputs, state)
        together = tf.concat([cell_state.c, cell_state.h], axis=1)
        return together, cell_state
However, now I get a different error:
ValueError: Dimension 1 in both shapes must be equal, but are 10 and 20 for 'rnn/while/Select' (op: 'Select') with input shapes: [?],
Clearly the framework does not expect the output to suddenly be bigger... But I do not understand why that should be the case: shouldn't the output simply be an output, with no bearing on the computation? It is very confusing to me why the "output" should affect the computation at all.
Is extending the LSTMCell class simply a bad way to do what I want? I do like the dynamic_rnn interface, but if I could just get at the intermediate states...