
I am dealing with long sequential data that has to be fed to an RNN. To do truncated BPTT with batching, there seem to be two options:

  1. Create a batch by combining respective segments from different sequences. Preserve the final state of each sequence in the batch and pass it on to the next batch (a hand-rolled sketch of this appears after the list).
  2. Treat each sequence as its own mini-batch, with the segments of that sequence becoming the batch members. Preserve the state of the last time step of one segment and pass it on to the first time step of the next segment.
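
For concreteness, here is a minimal hand-rolled sketch of option 1 using plain placeholders and tf.nn.dynamic_rnn (all sizes and names are made up for illustration): each batch stacks the i-th segment of every sequence, and the final state of one sess.run is fed back as the initial state of the next.

```python
import numpy as np
import tensorflow as tf  # TF 1.x, to match the tf.contrib API in question

# Toy setup: 4 sequences of length 60, feature size 8, cut into
# segments of 20 steps for truncated BPTT. All values are made up.
num_seqs, seq_len, num_unroll, feat, units = 4, 60, 20, 8, 16
data = np.random.randn(num_seqs, seq_len, feat).astype(np.float32)

inputs = tf.placeholder(tf.float32, [num_seqs, num_unroll, feat])
c_in = tf.placeholder(tf.float32, [num_seqs, units])
h_in = tf.placeholder(tf.float32, [num_seqs, units])
cell = tf.contrib.rnn.BasicLSTMCell(units)
outputs, final_state = tf.nn.dynamic_rnn(
    cell, inputs, initial_state=tf.contrib.rnn.LSTMStateTuple(c_in, h_in))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = np.zeros((num_seqs, units), np.float32)
    h = np.zeros((num_seqs, units), np.float32)
    # Option 1: batch i holds the i-th segment of *every* sequence, and
    # the final state of one run seeds the initial state of the next.
    for i in range(seq_len // num_unroll):
        segment = data[:, i * num_unroll:(i + 1) * num_unroll, :]
        (c, h), _ = sess.run([final_state, outputs],
                             feed_dict={inputs: segment, c_in: c, h_in: h})
```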

I came across tf.contrib.training.batch_sequences_with_states, which seems to do one of the two. The documentation is confusing to me, so I want to be certain which way it generates the batches.

My guess is that it uses the first strategy, because the second strategy would lose the benefits of vectorization: to carry the state from the last time step of one segment to the first time step of the next, the RNN would have to process the segments of a single sequence one after another rather than in parallel.

Question:

Which of these two batching strategies is implemented in tf.contrib.training.batch_sequences_with_states?

dragster

1 Answer


tf.contrib.training.batch_sequences_with_states implements the former behavior. Each minibatch entry is a segment from a different sequence (each sequence, which may consist of a variable number of segments, has a unique key, and this key is passed into batch_sequences_with_states). When used with state_saving_rnn, the final state of each segment is saved into a special storage container, which allows the next segment of that sequence to be run on the next sess.run call. When a sequence's final segment has been processed, its minibatch slot is freed up for a different sequence.
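
A minimal sketch of this pattern, loosely following the example in the TF 1.x docstring (the constant key and random sequence below stand in for a real reader/parser pipeline, and the names "input" and "rnn_state" are arbitrary):

```python
import tensorflow as tf  # TF 1.x; tf.contrib was removed in TF 2.x

num_unroll, batch_size, num_units = 20, 32, 64

# Stand-ins for a real reader pipeline: each read should yield one whole
# sequence plus a key that is unique per sequence (the key is how the
# library tracks which segment of which sequence goes into which slot).
key = tf.constant("sequence_000")      # unique per sequence in practice
sequence = tf.random_normal([60, 8])   # [seq_len, feature_size]

batch = tf.contrib.training.batch_sequences_with_states(
    input_key=key,
    input_sequences={"input": sequence},
    input_context={},
    input_length=tf.shape(sequence)[0],
    initial_states={"rnn_state": tf.zeros([num_units], dtype=tf.float32)},
    num_unroll=num_unroll,
    batch_size=batch_size)

# Each minibatch row is a num_unroll-step segment of a *different* sequence.
inputs = batch.sequences["input"]      # [batch_size, num_unroll, 8]
inputs_by_time = [tf.squeeze(t, [1])
                  for t in tf.split(inputs, num_unroll, axis=1)]

# state_saving_rnn reads each row's saved state out of `batch` and writes
# the final state back, so the next sess.run continues those sequences.
cell = tf.contrib.rnn.GRUCell(num_units)
outputs, _ = tf.nn.static_state_saving_rnn(
    cell, inputs_by_time, state_saver=batch, state_name="rnn_state")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    sess.run(outputs)  # each run advances every batch slot by one segment
```

A GRU cell is used here only because its state is a single tensor; for an LSTM, whose state is a (c, h) tuple, you would save two named states instead of one.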

Eugene Brevdo