While training, I set epochs to the number of times to iterate over the data. So I was wondering: what is the use of tf.data.Dataset.repeat(EPOCHS) when I can already do the same thing with model.fit(train_dataset, epochs=EPOCHS)?

2 Answers
It does work slightly differently.
Let's pick 2 different examples.
- dataset.repeat(20) and model.fit(epochs=10)
- dataset.repeat(10) and model.fit(epochs=20)
Let's also assume that you have a dataset with 100 records.
If you pick option 1, each epoch will have 2,000 records. You will be "checking" how your model is improving after passing 2,000 records through it, and you will do that 10 times.
If you choose option 2, each epoch will have 1,000 records. You will be evaluating how your model is improving after pushing 1,000 records through it, and you will do that 20 times.
In both options, the total number of records you use for training is the same, but the "time" at which you evaluate, log, etc. the behavior of your model is different.
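To make the difference concrete, here is a minimal sketch. The data, model, and batch size are made up purely for illustration; both options push 20,000 records through the model in total, but metrics are logged at different intervals.

import tensorflow as tf

# Toy dataset with 100 records, as in the example above.
features = tf.random.normal([100, 4])
labels = tf.random.uniform([100], maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid")])
model.compile(optimizer="sgd", loss="binary_crossentropy", metrics=["accuracy"])

# Option 1: each epoch sees 100 * 20 = 2,000 records; metrics are logged 10 times.
model.fit(dataset.repeat(20), epochs=10)

# Option 2: each epoch sees 100 * 10 = 1,000 records; metrics are logged 20 times.
# (In practice you would pick one option, not run both on the same model.)
model.fit(dataset.repeat(10), epochs=20)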

- Could you please explain the last line in more detail? Will the final metric (like accuracy) of the model change or remain the same for the 2 examples? I think it should remain the same. – spb Apr 02 '21 at 06:17
- If your model is the same, you don't have things like dropout layers which are supposed to introduce randomness while training, and your batch size is also the same so that the gradients would be the same, then yes, the accuracy would be the same. The only difference is when you check how your training is progressing. Makes sense? – CrazyBrazilian Apr 03 '21 at 16:36
tf.data.Dataset.repeat() can be useful for data augmentation on a tf.data.Dataset, in the case of image data.
Suppose you want to increase the number of images in the training dataset using random transformations. You can repeat the training dataset count times and apply random transformations as shown below:
train_dataset = (
    train_dataset
    .map(resize, num_parallel_calls=AUTOTUNE)
    .map(rescale, num_parallel_calls=AUTOTUNE)
    .map(onehot, num_parallel_calls=AUTOTUNE)
    .shuffle(BUFFER_SIZE, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    # Repeat before the random transformations so that each of the
    # 5 passes over the data is augmented differently.
    .repeat(count=5)
    .map(random_flip, num_parallel_calls=AUTOTUNE)
    .map(random_rotate, num_parallel_calls=AUTOTUNE)
    .prefetch(buffer_size=AUTOTUNE)
)
Without the repeat() method you would have to create copies of the dataset, apply the transformations separately, and then concatenate the datasets. Using repeat() simplifies this, takes advantage of method chaining, and keeps the code neat.
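The map functions above (resize, rescale, onehot, random_flip, random_rotate) are not defined in the answer, so the following is only an illustration of what the two augmentation helpers might look like, assuming batched (images, labels) pairs:

import tensorflow as tf

def random_flip(images, labels):
    # Randomly flip the images horizontally (left to right).
    return tf.image.random_flip_left_right(images), labels

def random_rotate(images, labels):
    # Rotate the batch by a random multiple of 90 degrees; a simple stand-in,
    # since arbitrary-angle rotation needs e.g. Keras preprocessing layers.
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(images, k=k), labels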
More on data augmentation: https://www.tensorflow.org/tutorials/images/data_augmentation#apply_augmentation_to_a_dataset
