How is `train_on_batch()` different from `fit()`? In which cases should we use `train_on_batch()`?

Does this answer your question? [What does train\_on\_batch() do in keras model?](https://stackoverflow.com/questions/48550201/what-does-train-on-batch-do-in-keras-model) – nbro Jan 17 '20 at 15:11
5 Answers
For this question, there's a simple answer from the primary author:

With `fit_generator`, you can use a generator for the validation data as well. In general, I would recommend using `fit_generator`, but using `train_on_batch` works fine too. These methods only exist for the sake of convenience in different use cases, there is no "correct" method.
`train_on_batch` allows you to expressly update weights based on a collection of samples you provide, without regard to any fixed batch size. You would use this when that is exactly what you want: to train on an explicit collection of samples. You could use this approach to maintain your own iteration over multiple batches of a traditional training set, but letting `fit` or `fit_generator` iterate over the batches for you is likely simpler.
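A minimal sketch of the "maintain your own iteration" idea: a manual epoch/batch loop that calls `train_on_batch` on explicitly chosen samples. The synthetic data, the one-unit `Dense` model and the hyperparameters are illustrative assumptions, and the code assumes a recent `keras` package (tf.keras 2.x or Keras 3).

```python
import numpy as np
import keras

# Synthetic regression data (illustrative assumption): y = 3x + 1 plus a little noise.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 1)).astype("float32")
y = (3.0 * x + 1.0 + 0.01 * rng.normal(size=(256, 1))).astype("float32")

model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1), loss="mse")

batch_size = 32
epoch_losses = []
for epoch in range(5):
    # Our own cursor over the training set: shuffle once per epoch, then walk it in slices.
    order = rng.permutation(len(x))
    batch_losses = []
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        # Exactly one gradient update on exactly these samples.
        batch_losses.append(float(model.train_on_batch(x[idx], y[idx])))
    epoch_losses.append(float(np.mean(batch_losses)))
    print(f"epoch {epoch}: mean batch loss {epoch_losses[-1]:.4f}")

# The same number of updates, with fit managing the cursor for you:
# model.fit(x, y, batch_size=32, epochs=5, shuffle=True)
```

The loop is exactly what `fit` does internally in simplified form, which is why letting `fit` do it is usually simpler.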
One case where it might be nice to use `train_on_batch` is updating a pre-trained model on a single new batch of samples. Suppose you've already trained and deployed a model, and some time later you receive a new set of training samples that were never used before. You could use `train_on_batch` to update the existing model directly on only those samples. Other methods can do this too, but `train_on_batch` makes the intent explicit in this case.
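A sketch of that scenario, assuming a toy "already trained" model (the data, architecture and the count of 7 new samples are hypothetical stand-ins): the new samples get a single gradient step with no batch size, epochs or callbacks to configure.

```python
import numpy as np
import keras

rng = np.random.default_rng(1)

# Stand-in for a model that was trained and deployed earlier (hypothetical data and architecture).
x_old = rng.normal(size=(500, 4)).astype("float32")
y_old = x_old.sum(axis=1, keepdims=True)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.fit(x_old, y_old, batch_size=32, epochs=2, verbose=0)

# Later: a handful of never-seen samples arrives (7 here, an arbitrary count).
x_new = rng.normal(size=(7, 4)).astype("float32")
y_new = x_new.sum(axis=1, keepdims=True)

weights_before = [w.copy() for w in model.get_weights()]
# One gradient step on exactly these 7 samples.
loss = model.train_on_batch(x_new, y_new)
weights_after = model.get_weights()
print("training loss reported for the new batch:", float(loss))
```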
Apart from special cases like this (either where you have some pedagogical reason to maintain your own cursor across different training batches, or for some type of semi-online training update on a special batch), it is probably better to just always use `fit` (for data that fits in memory) or `fit_generator` (for streaming batches of data from a generator).
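A short sketch of both recommended paths. One caveat worth knowing: `fit_generator` is deprecated in TensorFlow 2.x and absent from Keras 3, where `fit` itself accepts a Python generator (including one for `validation_data`). The generator, its seeds and the toy model below are hypothetical.

```python
import numpy as np
import keras

def batch_generator(n_batches, seed, batch_size=32):
    """Hypothetical streaming source: yields (x, y) batches forever, as Keras expects for generators."""
    rng = np.random.default_rng(seed)
    while True:
        for _ in range(n_batches):
            x = rng.normal(size=(batch_size, 3)).astype("float32")
            yield x, x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Data that fits in memory: let fit do the batching.
x_mem = np.random.default_rng(8).normal(size=(256, 3)).astype("float32")
model.fit(x_mem, x_mem.sum(axis=1, keepdims=True), batch_size=32, epochs=2, verbose=0)

# Streaming data: pass generators straight to fit; steps must be given for infinite generators.
history = model.fit(batch_generator(10, seed=1), steps_per_epoch=10, epochs=2,
                    validation_data=batch_generator(3, seed=2), validation_steps=3, verbose=0)
print(sorted(history.history.keys()))
```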
Why can't you just use `fit()` for this? By adjusting the `batch_size` argument, we can also do a single weight-update iteration on the new batch of samples. Is there any difference? – hirschme Jun 23 '22 at 20:34
Yeah, I don't see why we can't use the `fit` method itself for the new samples that just came in. Or does Keras reset the weights whenever someone calls `fit`? I don't think so, AFAIK. – theprogrammer Dec 04 '22 at 20:15
To use `fit` you would need to provide a batch size that matches your ad hoc batch. Usually `fit` is buried in some code that abstracts the parameters out into some other environment (either for experimentation, like in a notebook, or some config, like for scheduled training jobs), so it could be a pain to override the batch size. (If you omit the batch size, `fit` defaults to `32`.) `train_on_batch` does not require a batch size. So, for example, if you have already trained a model, don't want to mess with configs, but want to fine-tune or test training code on one batch, `train_on_batch` is better – ely May 07 '23 at 18:54
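A sketch that makes the point of this exchange concrete, under assumed toy data: two models start from identical weights, one takes a `train_on_batch` step and the other calls `fit` with `batch_size=len(x_new)`, `epochs=1` and `shuffle=False`. Both perform one identical update; with the default `batch_size=32`, these 50 hypothetical samples would instead be split into 2 updates.

```python
import numpy as np
import keras

def make_model():
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.01), loss="mse")
    return model

rng = np.random.default_rng(2)
x_new = rng.normal(size=(50, 4)).astype("float32")
y_new = x_new.sum(axis=1, keepdims=True)

model_a = make_model()
model_b = make_model()
model_b.set_weights(model_a.get_weights())  # identical starting point

# Option 1: one update with train_on_batch; no batch size to configure.
model_a.train_on_batch(x_new, y_new)

# Option 2: the same single update with fit, only because batch_size covers the whole ad hoc batch.
model_b.fit(x_new, y_new, batch_size=len(x_new), epochs=1, shuffle=False, verbose=0)

for wa, wb in zip(model_a.get_weights(), model_b.get_weights()):
    print("max weight difference:", float(np.abs(wa - wb).max()))
```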
`train_on_batch()` gives you greater control over the state of the LSTM, for example when using a stateful LSTM where you need to control the calls to `model.reset_states()`. You may have multi-series data and need to reset the state after each series, which you can do with `train_on_batch()`; if you used `.fit()`, the network would be trained on all the series of data without resetting the state. There's no right or wrong; it depends on what data you're using and how you want the network to behave.
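A minimal sketch of that pattern, assuming hypothetical data (3 independent random-walk series, each cut into 4 consecutive windows of 10 steps, target = next value). The state carries across the windows of one series and is cleared before the next series. It calls `reset_states()` on the LSTM layer itself to stay version-agnostic; the stateful layer needs a fixed batch size, set here via `keras.Input(..., batch_size=1)`.

```python
import numpy as np
import keras

rng = np.random.default_rng(3)
timesteps, n_windows = 10, 4
# Hypothetical multi-series data: each series has enough points for n_windows windows plus one target.
series = [np.cumsum(rng.normal(size=timesteps * n_windows + 1)).astype("float32") for _ in range(3)]

inputs = keras.Input(shape=(timesteps, 1), batch_size=1)  # stateful layers need a fixed batch size
lstm = keras.layers.LSTM(8, stateful=True)
outputs = keras.layers.Dense(1)(lstm(inputs))
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

losses = []
for epoch in range(2):
    for s in series:
        for w in range(n_windows):
            start = w * timesteps
            x = s[start:start + timesteps].reshape(1, timesteps, 1)
            y = s[start + timesteps:start + timesteps + 1].reshape(1, 1)  # next value after the window
            # The LSTM state carries over from the previous window of the SAME series.
            losses.append(float(model.train_on_batch(x, y)))
        # This series is finished: clear the carried state before the next, unrelated series.
        lstm.reset_states()
```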

Exactly my use case. I was searching for this question to see whether it made sense to do it this way, as I was having a hell of a time trying to force it with `fit`. – adamconkey Apr 01 '19 at 23:25
`train_on_batch` will also see a performance increase over `fit` and `fit_generator` if you're using large datasets and your data isn't easily serializable to TFRecords (like high-rank NumPy arrays).

In this case, you can save the arrays as NumPy files and load smaller subsets of them (traina.npy, trainb.npy, etc.) into memory when the whole set won't fit. You can then use `tf.data.Dataset.from_tensor_slices` on a subset and call `train_on_batch` on its batches, then load the next subset and call `train_on_batch` again, and so on. Once every subset has been through, you've trained on your entire set, and you control exactly how much of your dataset, and which parts, train your model. You can then define your own epochs, batch sizes, etc. with simple loops and functions that draw from your dataset.
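A sketch of that chunked workflow, under stated assumptions: three small chunk files are generated in a temporary directory just so the example runs, and each row is assumed to hold 8 features followed by 1 target column (a hypothetical layout). Only one chunk is in memory at a time, and the epoch and batch loops are written by hand.

```python
import os
import shutil
import tempfile
import numpy as np
import tensorflow as tf
import keras

# Hypothetical on-disk layout: rows of 8 features + 1 target, split across chunk files.
rng = np.random.default_rng(4)
data_dir = tempfile.mkdtemp()
chunk_names = ["traina.npy", "trainb.npy", "trainc.npy"]
for name in chunk_names:
    feats = rng.normal(size=(1000, 8)).astype("float32")
    np.save(os.path.join(data_dir, name), np.hstack([feats, feats.sum(axis=1, keepdims=True)]))

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

batch_size, n_epochs, steps = 64, 2, 0
for epoch in range(n_epochs):  # our own definition of an epoch: one pass over all chunks
    for name in chunk_names:
        chunk = np.load(os.path.join(data_dir, name))  # only this chunk is held in memory
        ds = (tf.data.Dataset.from_tensor_slices((chunk[:, :8], chunk[:, 8:]))
              .shuffle(len(chunk))
              .batch(batch_size))
        for x_batch, y_batch in ds.as_numpy_iterator():
            model.train_on_batch(x_batch, y_batch)
            steps += 1
        del chunk, ds  # release the chunk before loading the next one

shutil.rmtree(data_dir)
print("gradient updates:", steps)
```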

This `train_on_batch` is important for RL, which trains one step at a time; `fit` will be extremely slow. – Dee Mar 19 '21 at 07:01
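A simplified sketch of the one-step-at-a-time pattern from the comment: a DQN-style update on a single transition. It deliberately omits a replay buffer and a target network, and the transitions are random stand-ins for a real environment's step results (all hypothetical).

```python
import numpy as np
import keras

n_states, n_actions, gamma = 4, 2, 0.99
q_net = keras.Sequential([
    keras.Input(shape=(n_states,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(n_actions),
])
q_net.compile(optimizer="adam", loss="mse")
rng = np.random.default_rng(5)

def q_update(state, action, reward, next_state, done):
    """One Q-learning update on a single transition (simplified sketch)."""
    q_values = q_net.predict_on_batch(state[None])  # shape (1, n_actions)
    next_q = q_net.predict_on_batch(next_state[None])
    target = np.array(q_values, copy=True)
    # Only the taken action's target changes; the other outputs keep their current predictions.
    target[0, action] = reward if done else reward + gamma * next_q[0].max()
    return float(q_net.train_on_batch(state[None], target))

losses = []
for step in range(20):
    s = rng.normal(size=n_states).astype("float32")
    s_next = rng.normal(size=n_states).astype("float32")
    losses.append(q_update(s, int(rng.integers(n_actions)), float(rng.normal()), s_next,
                           done=(step % 10 == 9)))
```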
From Keras - Model training APIs:
- fit: Trains the model for a fixed number of epochs (iterations on a dataset).
- train_on_batch: Runs a single gradient update on a single batch of data.
We can use it in a GAN when we update the discriminator and the generator using one batch of our training dataset at a time. I saw Jason Brownlee use `train_on_batch` in one of his tutorials (How to Develop a 1D Generative Adversarial Network From Scratch in Keras).
Tip for a quick search: press Ctrl+F and type the term you want to find (`train_on_batch`, for example) into the search box.
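A compact sketch of the GAN use case, loosely modeled on a 1D-function GAN (real points are `(x, x**2)`); the layer sizes, latent size and step count are assumptions, not the tutorial's exact code. `discriminator.trainable` is set explicitly before each `train_on_batch` call rather than relying on the value captured when each model was compiled, since that behavior has differed between Keras versions.

```python
import numpy as np
import keras

latent_dim, half_batch = 5, 64
rng = np.random.default_rng(6)

def real_samples(n):
    x = rng.uniform(-0.5, 0.5, size=(n, 1))
    return np.hstack([x, x ** 2]).astype("float32"), np.ones((n, 1), "float32")

def fake_samples(generator, n):
    z = rng.normal(size=(n, latent_dim)).astype("float32")
    return generator.predict_on_batch(z), np.zeros((n, 1), "float32")

discriminator = keras.Sequential([keras.Input(shape=(2,)), keras.layers.Dense(25, activation="relu"),
                                  keras.layers.Dense(1, activation="sigmoid")])
discriminator.compile(optimizer="adam", loss="binary_crossentropy")
generator = keras.Sequential([keras.Input(shape=(latent_dim,)), keras.layers.Dense(15, activation="relu"),
                              keras.layers.Dense(2)])

# Combined model generator -> discriminator, used only to update the generator.
discriminator.trainable = False
z_in = keras.Input(shape=(latent_dim,))
gan = keras.Model(z_in, discriminator(generator(z_in)))
gan.compile(optimizer="adam", loss="binary_crossentropy")

d_losses, g_losses = [], []
for step in range(100):
    # 1) Discriminator: one batch of real samples, one batch of generated samples.
    discriminator.trainable = True
    d_losses.append(float(discriminator.train_on_batch(*real_samples(half_batch))))
    d_losses.append(float(discriminator.train_on_batch(*fake_samples(generator, half_batch))))
    # 2) Generator: fakes labeled "real" (1) so the update pushes the generator to fool the discriminator.
    discriminator.trainable = False
    z = rng.normal(size=(2 * half_batch, latent_dim)).astype("float32")
    g_losses.append(float(gan.train_on_batch(z, np.ones((2 * half_batch, 1), "float32"))))
```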

Indeed, @nbro's answer helps. Just to add a few more scenarios: let's say you are training a seq2seq model or a large network with one or more encoders. We can create custom training loops using `train_on_batch` and use part of our data to validate the encoder directly, without using callbacks. Writing callbacks for a complex validation process could be difficult. There are several cases where we want to train on a batch.
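A sketch of such a loop, assuming a hypothetical autoencoder built from a separate `encoder` and `decoder`: validation runs inline every 50 steps with `test_on_batch`, and the encoder is queried directly. The saturation check is an arbitrary example of an encoder-level metric, not a standard one.

```python
import numpy as np
import keras

rng = np.random.default_rng(7)
x_train = rng.normal(size=(512, 16)).astype("float32")
x_val = rng.normal(size=(128, 16)).astype("float32")

# Hypothetical encoder/decoder pair, trained end to end as an autoencoder.
encoder = keras.Sequential([keras.Input(shape=(16,)), keras.layers.Dense(4, activation="tanh")], name="encoder")
decoder = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(16)], name="decoder")
inputs = keras.Input(shape=(16,))
autoencoder = keras.Model(inputs, decoder(encoder(inputs)))
autoencoder.compile(optimizer="adam", loss="mse")

batch_size, history = 32, []
for step in range(200):
    idx = rng.integers(0, len(x_train), size=batch_size)
    train_loss = float(autoencoder.train_on_batch(x_train[idx], x_train[idx]))
    if step % 50 == 0:
        # Inline validation: no callback needed, and the encoder is available directly.
        val_loss = float(autoencoder.test_on_batch(x_val, x_val))
        codes = encoder.predict_on_batch(x_val)
        saturation = float(np.mean(np.abs(codes) > 0.99))  # share of latent units near +/-1
        history.append((step, train_loss, val_loss, saturation))
        print(f"step {step}: train {train_loss:.4f}, val {val_loss:.4f}, saturation {saturation:.2f}")
```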
Regards, Karthick
