
I use the following code when training a model in Keras:

```python
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping

model = Sequential()
model.add(Dense(100, activation='relu', input_shape=input_shape))
model.add(Dense(1))

model_2.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

model.fit(X, y, epochs=15, validation_split=0.4, callbacks=[early_stopping_monitor], verbose=False)

model.predict(X_test)
```

but recently I wanted to save the best trained model, as the data I am training on produces a lot of peaks in the "val_loss vs epochs" graph, and I want to use the best weights the model has reached so far.

Is there any method or function to help with that?

gosuto
dJOKER_dUMMY

3 Answers


EarlyStopping and ModelCheckpoint from the Keras documentation are what you need.

Set save_best_only=True in ModelCheckpoint; any other adjustments you might need are trivial.

To help you further, you can see a usage example here on Kaggle.


Adding the code here in case the above Kaggle example link is not available:

```python
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

model = getModel()
model.summary()

batch_size = 32

earlyStopping = EarlyStopping(monitor='val_loss', patience=10, verbose=0, mode='min')
mcp_save = ModelCheckpoint('.mdl_wts.hdf5', save_best_only=True, monitor='val_loss', mode='min')
# note: `epsilon` was renamed to `min_delta` in later Keras versions
reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, epsilon=1e-4, mode='min')

model.fit(Xtr_more, Ytr_more, batch_size=batch_size, epochs=50, verbose=0,
          callbacks=[earlyStopping, mcp_save, reduce_lr_loss], validation_split=0.25)
```
Shridhar R Kulkarni
  • Could you please attach the example here? If the link breaks, the answer will become useless – xenteros Oct 30 '18 at 05:38
  • I am getting a `KeyError: 'lr'` because of ReduceLROnPlateau. Why? – Neeraj Kumar Jul 21 '19 at 14:17
  • @NeerajKumar: Please read https://realpython.com/python-keyerror/ and https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ReduceLROnPlateau. If reading these docs didn't help you, then please post a separate question along with the part of the code that is causing an error. – Shridhar R Kulkarni Jul 21 '19 at 18:41

EarlyStopping's restore_best_weights argument will do the trick:

restore_best_weights: whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.

I'm not sure how your early_stopping_monitor is defined, but going with all the default settings, and seeing you have already imported EarlyStopping, you could do this:

```python
early_stopping_monitor = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=0,
    verbose=0,
    mode='auto',
    baseline=None,
    restore_best_weights=True
)
```

And then just call model.fit() with callbacks=[early_stopping_monitor] like you already do.
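Put together with the question's setup, the whole flow might look like this (the toy data and the `patience=3` value are made up for illustration; `patience=0` as in the defaults above would stop as soon as val_loss fails to improve once):

```python
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.callbacks import EarlyStopping

# toy regression data, purely for illustration
X = np.random.rand(200, 10)
y = X.sum(axis=1)

model = Sequential([
    Input(shape=(10,)),
    Dense(100, activation='relu'),
    Dense(1),
])
model.compile(optimizer='adam', loss='mean_squared_error')

early_stopping_monitor = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,  # roll back to the best epoch's weights
)

history = model.fit(X, y, epochs=15, validation_split=0.4,
                    callbacks=[early_stopping_monitor], verbose=0)
```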

gosuto
  • This is the right answer. The reason why other answers have more votes is probably due to the fact that `restore_best_weights` was introduced in [Keras 2.2.3](https://github.com/keras-team/keras/releases/tag/2.2.3), released in October 2018, i.e. after [this answer](https://stackoverflow.com/a/48286003/738017). – Vito Gentile Sep 25 '20 at 14:43
  • Man... Keras is just *too* easy! – Ulf Aslak Oct 02 '20 at 11:09
  • Will it still restore the best weights when it didn't manage to early-stop but finished all epochs? – noone Jan 01 '22 at 21:48
  • Please note that currently `restore_best_weights=False` by default! – Ali Pardhan Apr 13 '22 at 01:46

I guess model_2.compile was a typo. This should help if you want to save the best model w.r.t. the val_loss:

```python
from keras.callbacks import ModelCheckpoint

# note: the metric keys in the filename must match what fit() logs:
# 'acc'/'val_acc' in older Keras, 'accuracy'/'val_accuracy' in newer versions
checkpoint = ModelCheckpoint('model-{epoch:03d}-{acc:03f}-{val_acc:03f}.h5',
    verbose=1,
    monitor='val_loss',
    save_best_only=True,
    mode='auto'
)

model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

model.fit(X, y, epochs=15, validation_split=0.4, callbacks=[checkpoint], verbose=False)
```
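Since this checkpoint saves the whole model (not just the weights), you can later reload the best one with `load_model`. A small self-contained sketch, with toy data and a simplified filename pattern assumed for illustration (newer Keras versions prefer the `.keras` extension for whole-model checkpoints):

```python
import glob
import numpy as np
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.callbacks import ModelCheckpoint

# toy regression data, for illustration only
X = np.random.rand(80, 4)
y = X.sum(axis=1)

model = Sequential([Input(shape=(4,)), Dense(8, activation='relu'), Dense(1)])
model.compile(optimizer='adam', loss='mean_squared_error')

# one file per improvement; the zero-padded epoch number keeps names sortable
checkpoint = ModelCheckpoint('model-{epoch:03d}.keras', monitor='val_loss',
                             save_best_only=True, mode='min', verbose=0)
model.fit(X, y, epochs=5, validation_split=0.25, callbacks=[checkpoint], verbose=0)

# with save_best_only=True the newest checkpoint is the best model seen so far
best_path = sorted(glob.glob('model-*.keras'))[-1]
best_model = load_model(best_path)
preds = best_model.predict(X, verbose=0)
```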
Pallav Jha
Vivek