
What is the difference between batching your dataset with dataset.batch(batch_size) and batching it with the batch_size parameter of your model's .fit() function? Do they have the same functionality, or are they different?

Tomergt45
  • Duplicate of: https://stackoverflow.com/questions/62670041/batch-size-in-tf-model-fit-vs-batch-size-in-tf-data-dataset/62670148#62670148 – Timbus Calin Jul 02 '20 at 14:45

2 Answers


Check the documentation for the parameter batch_size in fit:

batch_size
Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).

So, if you are passing a dataset object for training, batch the dataset itself (for example with dataset.batch(batch_size)) and do not use the batch_size parameter; that parameter is only meant for the case where your X/Y values are NumPy arrays or TensorFlow tensors.
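To make the distinction concrete, here is a plain-Python sketch (no TensorFlow required) that models the two paths under simplifying assumptions. slice_arrays stands in for what fit does with in-memory arrays and batch_size, and prebatch stands in for a dataset of (x, y) pairs that was batched before being handed to fit. Both helper names are hypothetical, and the sketch ignores the fact that fit shuffles array inputs before each epoch by default (shuffle=True).

```python
import math
from typing import Iterator, List, Sequence, Tuple


def slice_arrays(x: Sequence, y: Sequence, batch_size: int = 32) -> Iterator[Tuple[list, list]]:
    """Hypothetical model of fit(x, y, batch_size=...) on in-memory data:
    consecutive slices of batch_size samples, plus a smaller final batch
    when the sample count is not a multiple of batch_size. The default of
    32 mirrors the documented default."""
    for start in range(0, len(x), batch_size):
        yield list(x[start:start + batch_size]), list(y[start:start + batch_size])


def prebatch(pairs: Sequence[Tuple], batch_size: int) -> List[Tuple[list, list]]:
    """Hypothetical model of a dataset of (x, y) pairs after .batch(batch_size):
    the batches already exist before fit ever sees the data."""
    batches = []
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        batches.append(([p[0] for p in chunk], [p[1] for p in chunk]))
    return batches


x = list(range(100))
y = [v * 2 for v in x]

from_fit = list(slice_arrays(x, y, batch_size=32))
from_dataset = prebatch(list(zip(x, y)), batch_size=32)

# Number of batches (gradient updates) per epoch in this model.
steps = math.ceil(len(x) / 32)
print(steps, [len(bx) for bx, _ in from_fit])  # 4 [32, 32, 32, 4]
print(from_fit == from_dataset)                # True
```

Without shuffling, both paths yield the same batches; the difference is who does the batching. With arrays, fit slices them using batch_size; with a dataset, you batch it yourself and leave batch_size unset.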

jdehesa

dataset.batch() combines consecutive elements of a dataset object into batches. For example:

>>> import tensorflow as tf
>>> dataset = tf.data.Dataset.range(8)
>>> dataset = dataset.batch(3)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]

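Note that batch() does not modify the dataset in place: it returns a new dataset whose elements are batches rather than individual samples, which is why the example reassigns dataset. Changing the element structure like this may not always be desirable, so I would recommend using this function solely as a step in preprocessing your dataset.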
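The grouping rule from the example above can be modeled in plain Python. batch_like is a hypothetical helper, not TensorFlow code; its drop_remainder flag mirrors the real drop_remainder argument of tf.data.Dataset.batch, which discards a final batch that is smaller than batch_size.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def batch_like(elements: Iterable[T], batch_size: int, drop_remainder: bool = False) -> List[List[T]]:
    """Hypothetical model of Dataset.batch: group consecutive elements,
    in order, into lists of batch_size."""
    batches: List[List[T]] = []
    current: List[T] = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    # A final, smaller batch is kept unless drop_remainder is set.
    if current and not drop_remainder:
        batches.append(current)
    return batches


print(batch_like(range(8), 3))                       # [[0, 1, 2], [3, 4, 5], [6, 7]]
print(batch_like(range(8), 3, drop_remainder=True))  # [[0, 1, 2], [3, 4, 5]]
```

Applying batch_like twice would nest the batches, just as calling .batch() twice on a dataset produces batches of batches, which is one reason to batch once, at the end of your preprocessing.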

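Moreover, you should not combine a dataset object with the batch_size parameter in fit: the documentation says not to, depending on the TensorFlow version it may raise an error, and the batch size is already determined by how the dataset was batched.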

The batch_size parameter should be used when NumPy arrays or TensorFlow tensors are passed as inputs to fit.

The example is taken from the official TensorFlow documentation, which can be found at the link below.

dataset.batch() - https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch

Parth Shah