
I saw a sample of code (too big to paste here) where the author used `model.train_on_batch(in, out)` instead of `model.fit(in, out)`. The official Keras documentation says:

> Single gradient update over one batch of samples.

But I don't get it. Is it the same as `fit()`, except that instead of doing many feed-forward and backprop steps it does only one? Or am I wrong?

nbro
CerushDope
    Possible duplicate of [What is the use of train\_on\_batch() in keras?](https://stackoverflow.com/questions/49100556/what-is-the-use-of-train-on-batch-in-keras) – nbro Oct 08 '19 at 22:35

3 Answers


Yes, `train_on_batch` trains on a single batch, and only once.

`fit`, by contrast, trains on many batches for many epochs (each batch causes one update of the weights).

The idea of using `train_on_batch` is probably to do more things yourself between batches.
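To illustrate the difference, here is a minimal sketch, assuming TensorFlow's bundled Keras (`from tensorflow import keras`) and made-up toy regression data, of roughly what `fit` does written out by hand with `train_on_batch`: shuffle each epoch, slice the data into batches, and make one `train_on_batch` call (one weight update) per batch. The helper `manual_fit` and the data are hypothetical, not part of Keras.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)

# Hypothetical toy data: 100 samples, 4 features, linear target.
x = rng.normal(size=(100, 4)).astype("float32")
true_w = np.array([1.0, -2.0, 0.5, 3.0], dtype="float32")
y = (x @ true_w)[:, None]

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1), loss="mse")


def manual_fit(model, x, y, batch_size=32, epochs=3):
    """Hypothetical hand-written approximation of fit(): shuffle each epoch,
    slice into mini-batches, one train_on_batch call (one update) per batch."""
    epoch_losses = []
    for epoch in range(epochs):
        order = rng.permutation(len(x))
        batch_losses = []
        for start in range(0, len(x), batch_size):
            idx = order[start:start + batch_size]
            # Single gradient update on this batch; returns the loss Keras reports.
            loss = model.train_on_batch(x[idx], y[idx])
            batch_losses.append(float(loss))
            # Custom work between batches (logging, changing the data,
            # switching models as in GAN training) would go here.
        epoch_losses.append(float(np.mean(batch_losses)))
    return epoch_losses


history = manual_fit(model, x, y)
print(history)
```

With 100 samples and `batch_size=32` this makes 4 updates per epoch, 12 in total over 3 epochs, the same count `fit` would make. It is only an approximation of `fit`, though: callbacks, validation and progress reporting are left out.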

Daniel Möller
    Could you please add the reason we use `train_on_batch`? Is it used when the whole dataset cannot fit into memory, so we break it into chunks (batches) and use `train_on_batch()` instead of `fit()`? – Shashi Tunga Jun 21 '18 at 09:58
    If your data doesn't fit in memory: (1) if the problem is in the model, use smaller batches; (2) if the problem is the total numpy array, use a generator and `fit_generator`. Only use `train_on_batch` if you want to do things manually between batches. – Daniel Möller Jun 21 '18 at 13:45

It is used when we want to inspect what is happening or make custom changes after each batch is trained.

A more precise use case is GANs. You have to update the discriminator, but while updating the combined GAN network you have to keep the discriminator untrainable. So you first train the discriminator, and then train the GAN with the discriminator kept untrainable. See this for more details: https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3
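Here is a minimal sketch of that pattern with `train_on_batch`, assuming TensorFlow's bundled Keras and a toy 2-D "real" distribution. The layer sizes, `sample_real`, `sample_noise`, `train_discriminator` and `train_generator` are hypothetical and are not taken from the linked article. The discriminator is compiled while trainable, then frozen before the stacked `gan` model is built and compiled. The sketch also sets `discriminator.trainable` explicitly before each call, because how a changed `trainable` flag is picked up may differ between Keras versions.

```python
import numpy as np
from tensorflow import keras

LATENT_DIM, DATA_DIM, BATCH = 8, 2, 16  # hypothetical toy sizes
rng = np.random.default_rng(0)

generator = keras.Sequential([
    keras.Input(shape=(LATENT_DIM,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(DATA_DIM),
])
discriminator = keras.Sequential([
    keras.Input(shape=(DATA_DIM,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

# Compile the discriminator on its own while it is still trainable.
discriminator.compile(optimizer=keras.optimizers.Adam(1e-3), loss="binary_crossentropy")

# Freeze it, then stack generator -> discriminator into the combined GAN model.
discriminator.trainable = False
gan = keras.Sequential([keras.Input(shape=(LATENT_DIM,)), generator, discriminator])
gan.compile(optimizer=keras.optimizers.Adam(1e-3), loss="binary_crossentropy")


def sample_real(n):
    # Hypothetical "real" data: points clustered around (2, -1).
    return rng.normal(loc=(2.0, -1.0), scale=0.1, size=(n, DATA_DIM)).astype("float32")


def sample_noise(n):
    return rng.normal(size=(n, LATENT_DIM)).astype("float32")


def train_discriminator():
    # Real samples are labelled 1, generated samples 0.
    fake = generator.predict(sample_noise(BATCH), verbose=0)
    x = np.concatenate([sample_real(BATCH), fake])
    y = np.concatenate([np.ones((BATCH, 1)), np.zeros((BATCH, 1))]).astype("float32")
    discriminator.trainable = True   # this update should change D's weights
    return float(discriminator.train_on_batch(x, y))


def train_generator():
    discriminator.trainable = False  # D stays fixed; only G is updated
    # Label generated samples as "real" so G learns to fool D.
    y = np.ones((BATCH, 1), dtype="float32")
    return float(gan.train_on_batch(sample_noise(BATCH), y))


for step in range(3):
    d_loss = train_discriminator()
    g_loss = train_generator()
    print(step, d_loss, g_loss)
```

The two `train_on_batch` calls per step are exactly what `fit` cannot express on its own: two different models, updated alternately, sharing the discriminator's weights.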

TBhavnani

The `fit` method trains the model for `epochs` passes over the data you give it (one pass by default). However, because of memory limits (especially GPU memory), we can't train on a large number of samples at once, so we need to divide the data into small pieces called mini-batches (or just batches). Keras's `fit` does this splitting for you and passes through all the data you gave it.
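To make the batching concrete, here is a small sketch assuming TensorFlow's bundled Keras and random toy data. With 100 samples and `batch_size=32`, `fit` should perform `ceil(100 / 32) = 4` weight updates per epoch (batches of 32, 32, 32 and 4). This can be checked through the optimizer's `iterations` counter, which counts applied updates.

```python
import math

import numpy as np
from tensorflow import keras

# Hypothetical toy data: 100 samples, 4 features.
x = np.random.default_rng(0).normal(size=(100, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# fit() splits the 100 samples into mini-batches of 32 (32 + 32 + 32 + 4)
# and applies one weight update per mini-batch, for each of the 2 epochs.
model.fit(x, y, batch_size=32, epochs=2, verbose=0)

updates = int(model.optimizer.iterations.numpy())
expected = 2 * math.ceil(len(x) / 32)
print(updates, expected)  # prints: 8 8
```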

However, sometimes we need a more complicated training procedure. For example, we may want to randomly select new samples to put in the batch buffer each epoch (e.g. GAN training, Siamese CNN training, ...). In these cases we don't use the convenient, simple `fit` method but the `train_on_batch` method instead. To use it, we generate a batch of inputs and a batch of outputs (labels) in each iteration and pass them to the method. It trains the model on all the samples in the batch at once (a single update) and returns the loss and any other compiled metrics calculated with respect to the batch samples.
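Here is a minimal sketch of that style of loop, assuming TensorFlow's bundled Keras. The data pool, model and batch size are hypothetical. Each iteration draws a fresh random batch from a larger pool, calls `train_on_batch` once, and asks for the results as a dict (`return_dict=True`) so that the loss and the compiled `accuracy` metric come back labelled.

```python
import numpy as np
from tensorflow import keras

rng = np.random.default_rng(0)

# Hypothetical labelled pool: 1,000 points, label = whether x0 + x1 > 0.
pool_x = rng.normal(size=(1000, 2)).astype("float32")
pool_y = (pool_x.sum(axis=1) > 0).astype("float32")[:, None]

model = keras.Sequential([
    keras.Input(shape=(2,)),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer=keras.optimizers.Adam(0.05),
              loss="binary_crossentropy", metrics=["accuracy"])

BATCH = 64
logs = []
for it in range(50):
    # Draw a fresh random batch each iteration; the sampling rule is ours.
    idx = rng.choice(len(pool_x), size=BATCH, replace=False)
    # One update on this batch; the dict holds "loss" and "accuracy".
    out = model.train_on_batch(pool_x[idx], pool_y[idx], return_dict=True)
    logs.append({k: float(v) for k, v in out.items()})

print(logs[0], logs[-1])
```

Replacing the `rng.choice` line with another sampling rule (for instance, picking positive and negative pairs for a Siamese network) is where a loop like this gives you direct control over what goes into each batch.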

SELLAM