
I'm currently working through a Keras tutorial on recurrent network training and I'm having trouble understanding the stateful LSTM concept. To keep things as simple as possible, all sequences have the same length seq_length. As far as I understand it, the input data has shape (n_samples, seq_length, n_features) and we then train our LSTM on n_samples/M batches of size M as follows (see the sketch after this list):

For each batch:

  1. Feed in the 2D tensors of shape (seq_length, n_features) and, for each input 2D tensor, compute the gradient of the loss
  2. Sum these gradients to get the total gradient for the batch
  3. Backpropagate the gradient and update weights
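
To make that setup concrete, here is a minimal sketch of the stateless baseline I have in mind (using tf.keras; the layer sizes, variable names and dummy data are my own, not the tutorial's):

    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import layers

    n_samples, seq_length, n_features = 128, 10, 4
    M = 16  # batch size, giving n_samples / M = 8 batches per epoch

    # Dummy data: n_samples sequences, each one a (seq_length, n_features) 2D tensor.
    X = np.random.rand(n_samples, seq_length, n_features).astype("float32")
    y = np.random.rand(n_samples, 1).astype("float32")

    model = keras.Sequential([
        keras.Input(shape=(seq_length, n_features)),
        layers.LSTM(32),   # stateless by default: every sequence starts from a zero state
        layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

    # One gradient step per batch of M sequences (Keras averages the per-sample losses).
    model.fit(X, y, batch_size=M, epochs=2)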

In the tutorial's example, feeding in a 2D tensor means feeding in a sequence of seq_length letters, each encoded as a vector of length n_features. However, the tutorial says that in the Keras implementation of LSTMs, the hidden state is not reset after a whole sequence (2D tensor) has been fed in, but only after a whole batch of sequences, in order to use more context.

Why does keeping the hidden state of the previous sequence and using it as the initial hidden state for the current sequence improve learning and the predictions on our test set, given that this "previously learned" initial hidden state won't be available when making predictions? Moreover, Keras' default behaviour is to shuffle the input samples at the beginning of each epoch, so the batch context changes at every epoch. This behaviour seems to contradict keeping the hidden state through a batch, since the batch context is random.


1 Answer


LSTMs in Keras aren't stateful by default: each sequence starts with freshly reset states. By setting stateful=True on your recurrent layer, successive batches no longer reset the network state; the final hidden state of the sample at index i in one batch is used as the initial state of the sample at index i in the next batch. This assumes that those sequences really are successive, and it means that, in a (very informal) sense, you're training on sequences of length (n_samples/M) * seq_length.
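
As a rough illustration (a sketch using tf.keras, with layer sizes and names that are mine rather than the tutorial's): a stateful layer needs a fixed batch size, and it keeps its internal state across calls until you reset it explicitly.

    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import layers

    M, seq_length, n_features = 16, 10, 4  # the batch size must be fixed for a stateful layer

    model = keras.Sequential([
        keras.Input(shape=(seq_length, n_features), batch_size=M),
        layers.LSTM(32, stateful=True),  # final state of row i in one batch seeds row i in the next
        layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

    batch_1 = np.random.rand(M, seq_length, n_features).astype("float32")
    batch_2 = np.random.rand(M, seq_length, n_features).astype("float32")

    model.predict(batch_1)  # leaves the LSTM's internal state set
    model.predict(batch_2)  # continues from that carried-over state, row by row

    # Wipe the state once the "long" concatenated sequences end.
    for layer in model.layers:
        if getattr(layer, "stateful", False):
            layer.reset_states()

Without stateful=True, the two predict calls above would each start from a zero state and be completely independent.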

Why does keeping the hidden state of the previous sequence and using it as the initial hidden state for the current sequence improve learning and the predictions on our test set, given that this "previously learned" initial hidden state won't be available when making predictions?

In theory, it improves learning because a longer context can teach the network things about the distribution that are still relevant when testing on the individually shorter sequences. If the network is learning some probability distribution, that distribution should hold over different sequence lengths.

Moreover, Keras' default behaviour is to shuffle the input samples at the beginning of each epoch, so the batch context changes at every epoch. This behaviour seems to contradict keeping the hidden state through a batch, since the batch context is random.

As far as I know, setting stateful=True doesn't change how fit() shuffles; the Keras docs recommend passing shuffle=False yourself when training a stateful model, so that neither the batches nor the samples within them are reordered.
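
A rough sketch of that training pattern (again using tf.keras and the same illustrative shapes as above; the sequences are assumed to be stored so that row i of one batch is continued by row i of the next):

    import numpy as np
    from tensorflow import keras
    from tensorflow.keras import layers

    M, seq_length, n_features, n_batches = 16, 10, 4, 8

    model = keras.Sequential([
        keras.Input(shape=(seq_length, n_features), batch_size=M),
        layers.LSTM(32, stateful=True),
        layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

    X = np.random.rand(n_batches * M, seq_length, n_features).astype("float32")
    y = np.random.rand(n_batches * M, 1).astype("float32")

    for epoch in range(5):
        # shuffle=False preserves both the batch order and the order within each batch,
        # so the state carried from batch to batch stays meaningful.
        model.fit(X, y, batch_size=M, epochs=1, shuffle=False, verbose=0)
        # Reset the carried-over context before the next pass over the data.
        for layer in model.layers:
            if getattr(layer, "stateful", False):
                layer.reset_states()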

In general, when we give the network some initial state, we don't mean for that to be a universally better starting point. It just means that the network can take the information from previous sequences into account when training.

  • Ok, my bad. I just found in the Keras docs (https://keras.io/layers/recurrent/) that the hidden state of the sample in the i-th position within the batch will be fed as the input hidden state for the sample in the i-th position in the next batch. Does that mean that if we want to pass the hidden state from sample to sample we have to use batches of size 1, and therefore perform online gradient descent? Is there a way to pass the hidden state within a batch of size > 1 and perform gradient descent on that batch? – H.M. Jan 15 '17 at 20:30
  • Would love to know if you're right. So far I've been tripping up over the fact that if you have a batch_size > 1, it doesn't seem to make sense that the hidden state of the first element of the first batch gets passed to the first element of the second batch, and so on. It also makes me question whether the state is passed within a batch, because that would conflict with the prior methodology: does the second element of the second batch inherit state from the second element of the first batch, or from the first element of the second batch? – Michael Du Apr 26 '18 at 22:30