I have created a TensorFlow dataset, made it repeatable, shuffled it, divided it into batches, and constructed an iterator to get the next batch. But when I do this, elements are sometimes repeated (within and across batches), especially for small datasets. Why?
Contrary to what is stated in your own answer, no, shuffling and then repeating won't fix your problem.
The key source of your problem is that you batch, then shuffle/repeat. That way, the items in your batches will always be taken from contiguous samples in the input dataset. Batching should be one of the last operations you do in your input pipeline.
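To see why batching first is the problem, here is a minimal pure-Python sketch (no TensorFlow involved; `batch` and `batch_then_shuffle` are hypothetical helpers written only for illustration). Because the shuffle runs on already-formed batches, it can only reorder them; every batch still holds two neighbouring input elements.

```python
import random

def batch(items, batch_size):
    """Group consecutive items into lists of batch_size (the last one may be shorter)."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

def batch_then_shuffle(items, batch_size, seed=None):
    """Batch first, then shuffle: only the order of whole batches is randomized."""
    rng = random.Random(seed)
    batches = batch(items, batch_size)
    rng.shuffle(batches)  # permutes the batches, never their contents
    return batches

result = batch_then_shuffle(list(range(10)), 2, seed=0)
print(result)  # the order varies, but each batch is always [2k, 2k + 1]
```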
Expanding on the question slightly:
Now, there is a difference in the order in which you shuffle, repeat and batch, but it's not what you think. Quoting from the input pipeline performance guide:
If the repeat transformation is applied before the shuffle transformation, then the epoch boundaries are blurred. That is, certain elements can be repeated before other elements appear even once. On the other hand, if the shuffle transformation is applied before the repeat transformation, then performance might slow down at the beginning of each epoch related to initialization of the internal state of the shuffle transformation. In other words, the former (repeat before shuffle) provides better performance, while the latter (shuffle before repeat) provides stronger ordering guarantees.
Recapping
- Repeat, then shuffle: you lose the guarantee that all samples are processed in one epoch.
- Shuffle, then repeat: it is guaranteed that all samples will be processed before the next repeat begins, but you pay a slight performance cost.
Whichever you choose, do that before batching.
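The two orders can be compared without TensorFlow. The sketch below is a simplified, hypothetical model (`buffered_shuffle`, `repeat` and the two pipeline helpers are not tf.data APIs) of the buffer behaviour that the `tf.data.Dataset.shuffle` documentation describes: fill a buffer with `buffer_size` elements, emit a random one, and put the next input element in its slot. When repeat comes first, the buffer straddles epoch boundaries, so an element can reappear before others have been seen once; when shuffle comes first, each epoch is a complete permutation (reshuffled on every pass here, since the shared random generator keeps advancing).

```python
import itertools
import random

def buffered_shuffle(iterable, buffer_size, rng):
    """Buffer-based shuffle model: fill a buffer of buffer_size items, emit a
    uniformly random one, refill its slot with the next input item, and
    finally drain the leftovers in random order. Requires buffer_size >= 1."""
    it = iter(iterable)
    buffer = list(itertools.islice(it, buffer_size))
    for item in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)
    yield from buffer

def repeat(make_epoch, count):
    """Concatenate `count` fresh passes, each produced by calling make_epoch()."""
    for _ in range(count):
        yield from make_epoch()

def repeat_then_shuffle(n, epochs, buffer_size, rng):
    # The buffer sees one long stream, so it mixes elements across epochs.
    return list(buffered_shuffle(repeat(lambda: range(n), epochs), buffer_size, rng))

def shuffle_then_repeat(n, epochs, buffer_size, rng):
    # Each epoch is shuffled on its own, so every epoch is a full permutation.
    return list(repeat(lambda: buffered_shuffle(range(n), buffer_size, rng), epochs))

rng = random.Random(0)
print(repeat_then_shuffle(10, 2, 8, rng))
print(shuffle_then_repeat(10, 2, 8, rng))
```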
- Note that you can use the order [shuffle, batch, repeat] as long as you use `reshuffle_each_iteration=True` when shuffling. – Kilian Obermeier Apr 05 '19 at 10:41
You must shuffle first, and then repeat!
As the following examples show, the order of shuffling and repeating matters.
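Worst Ordering:
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)
ds = tf.data.Dataset.range(10)
ds = ds.batch(2)         # batches are fixed here: [0 1], [2 3], ...
ds = ds.repeat()
ds = ds.shuffle(100000)  # only reorders the already-formed batches
iterator = ds.make_one_shot_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:  # 10 elements / batch size 2 = 5 batches per epoch
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())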
Output:
------------
00: [6 7]
01: [2 3]
02: [6 7]
03: [0 1]
04: [8 9]
------------
05: [6 7]
06: [4 5]
07: [6 7]
08: [4 5]
09: [0 1]
------------
10: [2 3]
11: [0 1]
12: [0 1]
13: [2 3]
14: [4 5]
Bad Ordering:
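import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)
ds = tf.data.Dataset.range(10)
ds = ds.batch(2)         # batches are fixed here: [0 1], [2 3], ...
ds = ds.shuffle(100000)  # shuffles the 5 batches within one epoch
ds = ds.repeat()
iterator = ds.make_one_shot_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:  # 10 elements / batch size 2 = 5 batches per epoch
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())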
Output:
------------
00: [4 5]
01: [6 7]
02: [8 9]
03: [0 1]
04: [2 3]
------------
05: [0 1]
06: [4 5]
07: [8 9]
08: [2 3]
09: [6 7]
------------
10: [0 1]
11: [4 5]
12: [8 9]
13: [2 3]
14: [6 7]
Best Ordering:
Inspired by GPhilo's answer: the order of batching also matters. For batches to differ between epochs, one must shuffle first, then repeat, and finally batch. As the output shows, the batch contents now change from epoch to epoch, unlike in the other two orderings (although the same pair can still occur by chance, e.g. [0 2] at steps 06 and 14).
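import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)
ds = tf.data.Dataset.range(10)
ds = ds.shuffle(100000)  # shuffle individual elements (buffer > dataset size)
ds = ds.repeat()
ds = ds.batch(2)         # batch last, so batch contents vary
iterator = ds.make_one_shot_iterator()
next_batch = iterator.get_next()
with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:  # 10 elements / batch size 2 = 5 batches per epoch
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())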
Output:
------------
00: [2 5]
01: [1 8]
02: [9 6]
03: [3 4]
04: [7 0]
------------
05: [4 3]
06: [0 2]
07: [1 9]
08: [6 5]
09: [8 7]
------------
10: [7 3]
11: [5 9]
12: [4 1]
13: [8 6]
14: [0 2]

- ..you do realize the output of the "wrong" and "right" orders is the same, right? – GPhilo Apr 19 '18 at 08:15
- @GPhilo, Sorry, copied the wrong code. Thanks for the correction. It is fixed now. – Miladiouss Apr 19 '18 at 08:32
- Looking at your new output, I think you ran exactly into the distinction they make in the performance guide I quoted. Your epochs are not wrong, but since the buffer you take samples from is randomly shuffled before the repeat occurs, you can end up picking the same sample twice before exhausting all samples from the previous iteration (because the buffer is never fully emptied). Note that, while this is more evident for small datasets, in practice this is hardly a problem during training on bigger datasets. – GPhilo Apr 19 '18 at 08:39
- @GPhilo, based on your answer (thank you!), I have added a new ordering. What I did in the first two should seriously be avoided no matter how big or small the dataset, since batches remain identical. – Miladiouss Apr 19 '18 at 08:44
- Yes, batching always goes after shuffling. Ideally, batching is the very last operation you do (possibly followed by a `prefetch` for better performance) in your pipeline, because it's not really related to the preprocessing of the samples, it simply puts stuff together to feed more samples at the same time to the network. – GPhilo Apr 19 '18 at 08:47
- What is the difference between shuffle>batch>repeat and shuffle>repeat>batch? – Seymour Apr 13 '20 at 08:18
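One way to see the difference, sketched in pure Python under the assumption that the shuffle buffer is at least as large as the dataset (so each epoch is a full permutation); all helper names are hypothetical. With shuffle>batch>repeat, batching happens inside each epoch, so every epoch ends with a short batch whenever the dataset size is not a multiple of the batch size. With shuffle>repeat>batch, batching runs over the concatenated stream, so only the very last batch can be short, and a batch may mix elements from two epochs.

```python
import random

def batches(stream, batch_size):
    """Group a flat sequence into consecutive lists of batch_size, keeping the short remainder."""
    stream = list(stream)
    return [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]

def shuffled(n, rng):
    """One epoch of 0..n-1 in random order (models a buffer >= dataset size)."""
    items = list(range(n))
    rng.shuffle(items)
    return items

def shuffle_batch_repeat(n, batch_size, epochs, rng):
    # Batch within each epoch: a short batch closes every epoch.
    out = []
    for _ in range(epochs):
        out.extend(batches(shuffled(n, rng), batch_size))
    return out

def shuffle_repeat_batch(n, batch_size, epochs, rng):
    # Batch the concatenated stream: batches may straddle an epoch boundary.
    stream = []
    for _ in range(epochs):
        stream.extend(shuffled(n, rng))
    return batches(stream, batch_size)

rng = random.Random(0)
print([len(b) for b in shuffle_batch_repeat(7, 3, 2, rng)])  # [3, 3, 1, 3, 3, 1]
print([len(b) for b in shuffle_repeat_batch(7, 3, 2, rng)])  # [3, 3, 3, 3, 2]
```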
If you want the same behavior as Keras' `.fit()` function, for example, you can use:
dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)
This will iterate through the dataset in the same way that `.fit(epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True)` would. A quick example (eager execution is enabled for readability only; the behavior is the same in graph mode):
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()  # TensorFlow 1.x call; eager is the default in TF 2.x
NUM_SAMPLES = 7
BATCH_SIZE = 3
EPOCHS = 2
# Create the dataset
x = np.array([[2 * i, 2 * i + 1] for i in range(NUM_SAMPLES)])
dataset = tf.data.Dataset.from_tensor_slices(x)
# Shuffle, batch and repeat the dataset
dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)
# Iterate through the dataset (in eager mode a Dataset is directly iterable)
for batch in dataset:
    print(batch.numpy(), end='\n\n')
prints
[[ 8 9]
[12 13]
[10 11]]
[[0 1]
[2 3]
[4 5]]
[[6 7]]
[[ 4 5]
[10 11]
[12 13]]
[[6 7]
[0 1]
[2 3]]
[[8 9]]
You can see that even though `.batch()` was called before `.repeat()`, the batches are still different in every epoch. This is why we need to use `reshuffle_each_iteration=True`. If we did not reshuffle at each iteration, we would get the same batches in every epoch:
[[12 13]
[ 4 5]
[10 11]]
[[6 7]
[8 9]
[0 1]]
[[2 3]]
[[12 13]
[ 4 5]
[10 11]]
[[6 7]
[8 9]
[0 1]]
[[2 3]]
This can be detrimental when training on small datasets.
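The effect of `reshuffle_each_iteration` can also be mimicked in pure Python. This is a hypothetical model of shuffle>batch>repeat (the function is not part of TensorFlow): with reshuffling off, the permutation drawn for the first epoch is replayed in every later epoch, so the batches repeat exactly as in the output above; with it on, a new permutation is drawn for each epoch.

```python
import random

def epochs_of_batches(n, batch_size, epochs, reshuffle_each_iteration, seed=0):
    """Model of shuffle -> batch -> repeat over the elements 0..n-1.
    Returns one list of batches per epoch."""
    rng = random.Random(seed)
    order = list(range(n))
    rng.shuffle(order)  # permutation used for the first epoch
    result = []
    for _ in range(epochs):
        if reshuffle_each_iteration and result:
            rng.shuffle(order)  # fresh permutation for every later epoch
        # Slices copy the list, so later in-place shuffles don't alter stored batches.
        result.append([order[i:i + batch_size] for i in range(0, n, batch_size)])
    return result

for epoch in epochs_of_batches(7, 3, 2, reshuffle_each_iteration=False):
    print(epoch)  # both epochs print identical batches
```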