
How can I configure Keras to keep training until convergence, or until the loss reaches 0? I intentionally want to overfit the model. I don't want to set a number of epochs; I just want training to stop when it converges.

desertnaut
alyssaeliyah

3 Answers


Use an EarlyStopping callback. You may freely choose which loss/metric to observe and when to stop.

Usually, you would monitor the validation loss (`val_loss`), as it is the most important indicator of whether your model is still learning to generalize.

But since you want to overfit, you can monitor the training loss (`loss`) instead.

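The callback works with deltas (improvement relative to the best value seen so far), not with absolute values. This is usually good, because the loss does not necessarily have zero as its goal. The `baseline` argument does accept an absolute value, but it works the other way around: training stops if the monitored quantity fails to improve on the baseline, not when the baseline is reached.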
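If you really do want to stop at an absolute value, such as a training loss of (nearly) zero, a small custom callback can do it. This is a minimal sketch, assuming a Keras version where `keras.callbacks.Callback` exposes `self.model` and honours `model.stop_training`; the class name `StopAtLossThreshold` is hypothetical. Because floating-point losses rarely hit exactly 0, a small positive threshold is usually more practical:

```python
import keras


class StopAtLossThreshold(keras.callbacks.Callback):
    """Hypothetical helper: stop once the monitored value is <= threshold."""

    def __init__(self, monitor="loss", threshold=0.0):
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # logs holds the metrics computed for the epoch that just finished
        current = (logs or {}).get(self.monitor)
        if current is not None and current <= self.threshold:
            self.model.stop_training = True
```

It can be passed to `fit` alongside, or instead of, an `EarlyStopping` callback, for example `callbacks=[StopAtLossThreshold(threshold=1e-4)]`.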

So a typical callback, which monitors the validation loss, looks like this:

```python
from keras.callbacks import EarlyStopping
usualCallback = EarlyStopping()
```

This is the same as `EarlyStopping(monitor='val_loss', min_delta=0, patience=0)`, since those are the default values.

One that will overfit:

```python
overfitCallback = EarlyStopping(monitor='loss', min_delta=0, patience=20)
```

Pay attention to the `patience` argument: it is important because the loss does not always decrease at every epoch. It lets the model keep trying for a few more epochs before training ends.
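To make the interaction between `min_delta` and `patience` concrete, here is a simplified, framework-free model of the stopping rule when watching a loss. It is an approximation written for illustration (the function name `epoch_at_which_training_stops` is hypothetical), not Keras's own source, and it ignores `baseline` and `restore_best_weights`:

```python
import math


def epoch_at_which_training_stops(losses, min_delta=0.0, patience=0):
    """Simplified model of an EarlyStopping rule that minimizes a loss.

    Returns the 0-based epoch at which training would stop, or None if the
    rule never fires for this sequence of per-epoch losses.
    """
    best = math.inf
    wait = 0  # consecutive epochs without a large-enough improvement
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            # Improvement: remember the new best value and reset the counter
            best = loss
            wait = 0
            continue
        wait += 1
        if wait >= patience:
            return epoch
    return None


# A noisy but still-decreasing loss curve
losses = [1.0, 0.8, 0.85, 0.7, 0.72, 0.71, 0.69]
print(epoch_at_which_training_stops(losses, patience=0))  # 2
print(epoch_at_which_training_stops(losses, patience=2))  # 5
print(epoch_at_which_training_stops(losses, patience=3))  # None
```

With `patience=0`, the single bump at epoch 2 ends training even though the loss keeps falling afterwards; `patience=3` rides out the noise.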

Finally, just pass the callback to fit along with a huge number of epochs:

```python
model.fit(X, Y, epochs=100000000, callbacks=[overfitCallback])
```
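For a self-contained run, the sketch below applies the same idea to a toy regression problem. The data, model, and optimizer settings are assumptions chosen so that a single `Dense` unit can fit the data exactly, and a tiny `min_delta` is used instead of 0 so the run ends once improvements become negligible. The epoch count is only a safety cap; `EarlyStopping` decides when to stop:

```python
import numpy as np
import keras
from keras.callbacks import EarlyStopping

# Toy data (assumed for illustration): y = 2x + 1, exactly fittable by one Dense unit
X = np.linspace(-1.0, 1.0, 8, dtype="float32").reshape(-1, 1)
Y = 2.0 * X + 1.0

model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.1), loss="mse")

# Watch the training loss; a tiny min_delta ends the run once gains become negligible
overfitCallback = EarlyStopping(monitor="loss", min_delta=1e-8, patience=20)

max_epochs = 100_000  # only a safety cap; the callback decides when to stop
history = model.fit(X, Y, epochs=max_epochs, batch_size=8,
                    callbacks=[overfitCallback], verbose=0)

epochs_run = len(history.history["loss"])
print(f"stopped after {epochs_run} epochs, final loss {history.history['loss'][-1]:.2e}")
```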
Daniel Möller

An EarlyStopping callback does exactly what you want: it stops training when the monitored quantity (here, the loss) has stopped improving. The `patience` parameter sets how many epochs without improvement (a sign of possible convergence) to wait before training stops. More EarlyStopping usage information can also be found in the possible duplicate.

It also helps to visualize the training process.
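One way to do that, assuming matplotlib is installed, is to plot the per-epoch values that `fit` records in `history.history` (a dict mapping names such as `loss` and `val_loss` to lists with one value per epoch). The function name `plot_training_loss` and the default output path are hypothetical:

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # render to a file; no display needed
import matplotlib.pyplot as plt


def plot_training_loss(history_dict, out_path=None):
    """Plot per-epoch training/validation loss from a History.history dict."""
    if out_path is None:
        out_path = os.path.join(tempfile.gettempdir(), "training_loss.png")
    fig, ax = plt.subplots()
    for key in ("loss", "val_loss"):
        values = history_dict.get(key)
        if values:
            # Epochs are numbered from 1 on the x-axis
            ax.plot(range(1, len(values) + 1), values, label=key)
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)
    return out_path
```

Usage: `history = model.fit(...)` followed by `plot_training_loss(history.history)`. A curve that has flattened out for many epochs is the visual counterpart of what `patience` detects.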

marilena.oita

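If you want to stop Keras training manually, you can use the mouse position as the trigger. Note that this only works on Windows, because it calls the Win32 API through `ctypes.windll`: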

```python
import keras


def queryMousePosition():
    # Windows only: ctypes.windll does not exist on other platforms
    from ctypes import windll, Structure, c_long, byref

    class POINT(Structure):
        _fields_ = [("x", c_long), ("y", c_long)]

    pt = POINT()
    windll.user32.GetCursorPos(byref(pt))
    return pt.x, pt.y


class TerminateOnFlag(keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        # Move the cursor to the left edge of the screen (x < 10) to stop training
        mouse_x, mouse_y = queryMousePosition()
        if mouse_x < 10:
            self.model.stop_training = True


callbacks = [keras.callbacks.ReduceLROnPlateau(), TerminateOnFlag()]

# fit_generator is deprecated in recent Keras; fit accepts generators directly
model.fit(..., callbacks=callbacks, ...)
```
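A cross-platform variant of the same idea swaps the mouse position for a flag file (a different trigger from the one above; the class name `TerminateOnFile` and the default file name are hypothetical). The callback checks after every batch whether the file exists, so creating it from another terminal, for example with `touch STOP_TRAINING`, should stop training after the current batch:

```python
import os

import keras


class TerminateOnFile(keras.callbacks.Callback):
    """Stop training as soon as the file at `path` exists."""

    def __init__(self, path="STOP_TRAINING"):  # hypothetical default file name
        super().__init__()
        self.path = path

    def on_batch_end(self, batch, logs=None):
        # A cheap filesystem check, performed once per batch
        if os.path.exists(self.path):
            self.model.stop_training = True
```

Remember to delete the file before starting the next run, otherwise training stops after the first batch.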
Mendi Barel