I am dealing with long sequential data that has to be passed to an RNN. To do truncated BPTT with batching, it seems there are two options:
- Create each batch by stacking the respective segments from different sequences. Preserve the final state of each sequence within a batch and pass it on as the initial state for the next batch.
- Treat each sequence as its own mini-batch, with the segments of that sequence becoming the batch members. Preserve the state of the last time step of one segment and pass it on to the first time step of the next segment. (I sketch both layouts right after this list.)
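To make the two layouts concrete, here is a minimal NumPy sketch; all shapes and names are made up for illustration:

```python
import numpy as np

# Hypothetical setup: 4 long sequences of length 100, unrolled 20 steps at a time.
num_seqs, seq_len, num_unroll, input_dim = 4, 100, 20, 8
sequences = np.random.randn(num_seqs, seq_len, input_dim).astype(np.float32)

# Option 1: batch t stacks segment t of every sequence -> shape [4, 20, 8].
# One state per sequence is carried from batch t to batch t+1.
option1_batches = [sequences[:, t:t + num_unroll]
                   for t in range(0, seq_len, num_unroll)]

# Option 2: one sequence's segments become the batch members -> shape [5, 20, 8].
# Here the state would have to flow from segment i to segment i+1 *within* a batch.
option2_batches = [seq.reshape(-1, num_unroll, input_dim) for seq in sequences]
```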
I came across `tf.contrib.training.batch_sequences_with_states`, which seems to do one of the two. The documentation is confusing to me, so I want to be certain which way it generates the batches.
My guess is that it does it the first way. If the batching were done the second way, we could not leverage the benefits of vectorization: to carry the state from the last time step of one segment to the first time step of the next, the RNN would have to process the segments sequentially, one at a time, rather than as a batch.
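To show what I mean by the first way staying vectorized, here is a minimal TF 1.x sketch (dimensions and variable names are made up; I use a GRU cell only because its state is a single tensor, which keeps the state feedback simple). The state remains a full `[batch_size, state_size]` tensor that is simply fed back in for the next segment:

```python
import numpy as np
import tensorflow as tf  # TF 1.x, same era as tf.contrib

batch_size, num_unroll, input_dim, state_size = 4, 20, 8, 32

# Dummy data: 3 consecutive segments, each of shape [batch_size, num_unroll, input_dim].
segments = [np.random.randn(batch_size, num_unroll, input_dim).astype(np.float32)
            for _ in range(3)]

inputs = tf.placeholder(tf.float32, [batch_size, num_unroll, input_dim])
# The previous segment's final state is fed back in through this placeholder.
init_state = tf.placeholder(tf.float32, [batch_size, state_size])

cell = tf.nn.rnn_cell.GRUCell(state_size)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    state = np.zeros([batch_size, state_size], np.float32)
    for segment in segments:
        # All 4 sequences advance num_unroll steps in one vectorized call;
        # the returned final_state becomes the initial state of the next segment.
        out, state = sess.run([outputs, final_state],
                              {inputs: segment, init_state: state})
```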
Question:
Which of these two batching strategies is implemented in `tf.contrib.training.batch_sequences_with_states`?