How do I configure Keras to keep training until convergence, or until the loss reaches 0? I intentionally want to overfit the model. I don't want to set a number of epochs; I just want training to stop when it converges.

Then just do a while loop, `while loss > 1e-10:` for instance? – Matthieu Brucher Nov 26 '18 at 10:02
3 Answers
Use an `EarlyStopping` callback. You may freely choose which loss/metric to observe and when to stop.
Usually you would look at the validation loss (`val_loss`), as this is the most important variable telling you whether your model is still learning to generalize.
But since you said you want to overfit, you can look at the training loss (`loss`) instead.
The callback works with "deltas", not with absolute values, which is good, because the loss doesn't necessarily have zero as its goal. But you can use the `baseline` argument to set an absolute value.
So, usually, a callback that looks at the validation loss:
from keras.callbacks import EarlyStopping
usualCallback = EarlyStopping()
This is the same as `EarlyStopping(monitor='val_loss', min_delta=0, patience=0)`.
One that will overfit:
overfitCallback = EarlyStopping(monitor='loss', min_delta=0, patience=20)
Watch out for the `patience` argument: it's important, because the loss value doesn't always decrease at every epoch. It lets the model keep trying for a few more epochs before stopping.
Finally, just pass the callback to `fit` along with a huge number of epochs:
model.fit(X, Y, epochs=100000000, callbacks=[overfitCallback])
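The `patience`/`min_delta` mechanics described above can be sketched in plain Python. This is a simplified illustration of the stopping rule, not Keras's actual implementation:

```python
def early_stop_epoch(losses, min_delta=0.0, patience=20):
    """Return the epoch index at which training would stop, or None.

    Mimics the core idea of EarlyStopping(monitor='loss'): stop once
    `patience` consecutive epochs pass without the loss improving on
    the best value seen so far by more than `min_delta`.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best - min_delta:
            best = loss   # new best loss; reset the counter
            wait = 0
        else:
            wait += 1     # no meaningful improvement this epoch
            if wait >= patience:
                return epoch
    return None           # training would run to the end
```

For example, with `patience=2` the loss curve `[1.0, 0.5, 0.6, 0.7, 0.8]` stops at epoch 3, after two consecutive epochs without improvement.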

An `EarlyStopping` callback will do exactly what you want: it stops training when the monitored quantity (the loss) has stopped improving. This is controlled by the `patience` parameter, which gives the number of epochs after which, if no improvement is observed (i.e. possible convergence), training should stop. More `EarlyStopping` usage information can be found at the possible duplicate.
It is also helpful to visualize the training process.

If you want to manually stop Keras, you can use the mouse position as input (Windows-only, since it relies on `windll`):

    import keras
    from ctypes import windll, Structure, c_long, byref

    def queryMousePosition():
        class POINT(Structure):
            _fields_ = [("x", c_long), ("y", c_long)]
        pt = POINT()
        windll.user32.GetCursorPos(byref(pt))
        return pt.x, pt.y

    class TerminateOnFlag(keras.callbacks.Callback):
        def on_batch_end(self, batch, logs=None):
            mouse_x, mouse_y = queryMousePosition()
            if mouse_x < 10:  # move the mouse to the left screen edge to stop
                self.model.stop_training = True

    callbacks = [keras.callbacks.ReduceLROnPlateau(), TerminateOnFlag()]
    model.fit_generator(..., callbacks=callbacks, ...)
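A platform-independent variation on the same idea is to stop when a sentinel file appears. This is a sketch with a class name of my own choosing; in real Keras you would subclass `keras.callbacks.Callback`, but a plain class with the same `on_batch_end` hook illustrates the mechanism without requiring Keras installed:

```python
import os

class StopFileCallback:
    """Stop training when a sentinel file appears on disk.

    Create the file (e.g. `touch STOP_TRAINING`) from another
    terminal to end training cleanly after the current batch.
    """
    def __init__(self, stop_file="STOP_TRAINING"):
        self.stop_file = stop_file
        self.model = None  # set by Keras before training starts

    def on_batch_end(self, batch, logs=None):
        if os.path.exists(self.stop_file):
            self.model.stop_training = True
```

Keras checks the `stop_training` flag between batches, so setting it here ends training gracefully rather than killing the process.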
