I am trying to implement an attention mechanism for which I need the full sequence of cell states (just like the full sequence of hidden states). However, Keras's LSTM layer only returns the last cell state:
    output, state_h, state_c = layers.LSTM(units=45, return_state=True, return_sequences=True)(inputs)
Here state_c has shape (batch_size, 45), i.e. the cell state of the last timestep only, whereas output (the full sequence of hidden states) has shape (batch_size, 5, 45), where 5 is the time window length.
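For reference, a minimal reproduction of those shapes (the batch size of 2 and the 3 input features are arbitrary choices for the demo):

    import tensorflow as tf
    from tensorflow.keras import layers

    x = tf.random.normal((2, 5, 3))  # (batch, time_window, features)
    output, state_h, state_c = layers.LSTM(
        units=45, return_state=True, return_sequences=True)(x)
    print(output.shape)   # (2, 5, 45): hidden state at every timestep
    print(state_h.shape)  # (2, 45):    hidden state of the last timestep
    print(state_c.shape)  # (2, 45):    cell state of the last timestep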
Why does Keras not return the full sequence of cell states, and is there a better way to get it than the approach below?
    import tensorflow as tf
    from tensorflow.keras import layers

    full_hidden, full_cell, outputs = [], [], []
    state = None  # None on the first step -> the layer uses its default zero state

    inputs = layers.Input(shape=(time_window, features), dtype='float32')
    lstm = layers.LSTM(units=45, return_state=True)

    for i in range(time_window):
        # slice out timestep i and restore the time axis: (batch, 1, features)
        input_t = tf.expand_dims(inputs[:, i, :], 1)
        # run a single LSTM step, seeded with the previous step's states
        out, state_h, state_c = lstm(input_t, initial_state=state)
        state = [state_h, state_c]
        full_hidden.append(state_h)
        full_cell.append(state_c)
        outputs.append(out)
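The collected lists can then be stacked back into sequence tensors:

    # (batch, time_window, 45) each
    full_hidden_seq = tf.stack(full_hidden, axis=1)
    full_cell_seq = tf.stack(full_cell, axis=1)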
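One alternative I am wondering about — a sketch only, assuming layers.RNN accepts a cell that emits a nested per-step output, which I have not verified — is to subclass LSTMCell so the cell state becomes part of the output and let layers.RNN do the unrolling:

    class LSTMCellWithCellState(layers.LSTMCell):
        def __init__(self, units, **kwargs):
            super().__init__(units, **kwargs)
            # advertise a (hidden, cell) pair as the per-step output
            self.output_size = (self.units, self.units)

        def call(self, inputs, states, training=None):
            h, new_states = super().call(inputs, states, training=training)
            # new_states is [h, c]; emit both so return_sequences collects both
            return (h, new_states[1]), new_states

    rnn = layers.RNN(LSTMCellWithCellState(45), return_sequences=True)
    hidden_seq, cell_seq = rnn(inputs)  # each (batch, time_window, 45)

Would that be preferable to the manual loop?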