
I am trying to implement an attention mechanism, and I need the full sequence of the cell state (just like the full sequence of the hidden state). Keras LSTM, however, only returns the last cell state:

lstm = layers.LSTM(units=45, return_state=True, return_sequences=True)
output, state_h, state_c = lstm(inputs)

state_c has shape (batch_size, 45), i.e. only the final cell state, whereas output (the full sequence of hidden states) has shape (batch_size, 5, 45), where 5 is the time-window length.
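For reference, here is a minimal check of the shapes (the batch size of 2 and the 3 input features are made-up numbers):

import tensorflow as tf
from tensorflow.keras import layers

x = tf.random.normal((2, 5, 3))  # (batch, time_window, features)
lstm = layers.LSTM(units=45, return_state=True, return_sequences=True)
output, state_h, state_c = lstm(x)
print(output.shape)   # (2, 5, 45): hidden state at every timestep
print(state_h.shape)  # (2, 45): final hidden state only
print(state_c.shape)  # (2, 45): final cell state only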

Why does Keras not return the full sequence of cell states? And is there a better approach to getting the full cell-state sequence than the one below?

import tensorflow as tf
from tensorflow.keras import layers

full_hidden, full_cell, outputs = [], [], []
state = None
inputs = layers.Input(shape=(time_window, features), dtype='float32')
lstm = layers.LSTM(units=45, return_state=True)  # one shared layer, applied once per timestep

for i in range(time_window):
    input_t = inputs[:, i, :]
    input_t = tf.expand_dims(input_t, 1)  # (batch, 1, features)
    out, state_h, state_c = lstm(input_t, initial_state=state)
    state = [state_h, state_c]  # carry this step's states into the next step
    full_hidden.append(state_h)
    full_cell.append(state_c)
    outputs.append(out)
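To get sequence tensors back out of the Python lists, I stack along the time axis; a sketch of how I wire it into a model (assuming the tensors above):

full_hidden_seq = tf.stack(full_hidden, axis=1)  # (batch, time_window, 45)
full_cell_seq = tf.stack(full_cell, axis=1)      # (batch, time_window, 45)
model = tf.keras.Model(inputs, [full_hidden_seq, full_cell_seq])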
bcsta

1 Answer


You need to set the flag return_sequences=True to get the hidden state at every timestep. The return_state=True flag that you are using makes the layer return only the final hidden and cell states.
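Note that return_sequences=True only covers the hidden state h; the cell state c is still returned only for the last step. If you also need c at every timestep without the per-step Python loop from the question, one option is to subclass layers.LSTMCell so that each step's output carries both states, and drive it with layers.RNN. A minimal sketch, assuming TF 2.x (LSTMCellWithCellState is a made-up name, and the time window of 5 with 3 features is illustrative):

import tensorflow as tf
from tensorflow.keras import layers

class LSTMCellWithCellState(layers.LSTMCell):
    """LSTM cell whose per-step output is h and c concatenated."""
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.output_size = 2 * units  # tell the RNN wrapper the step output is twice as wide

    def call(self, inputs, states, training=None):
        h, new_states = super().call(inputs, states, training=training)
        c = new_states[1]  # LSTMCell state is [h, c]
        return tf.concat([h, c], axis=-1), new_states

inputs = layers.Input(shape=(5, 3))
seq = layers.RNN(LSTMCellWithCellState(45), return_sequences=True)(inputs)
full_hidden, full_cell = tf.split(seq, 2, axis=-1)  # each (batch, 5, 45)

The concat/split trick keeps the cell's step output a single tensor, which keeps the RNN wrapper's shape inference simple, while still recovering both per-step sequences.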

Jindřich