
How can I use early stopping in my code, and where should I put it?

```python
callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode="auto")]
```

My code:


```python
import tensorflow as tf
from numpy.random import seed
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
from sklearn.pipeline import Pipeline
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
# Legacy scikit-learn wrapper; not available in recent TensorFlow releases
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor

seed(1)

def create_model(optimizer='rmsprop'):
    model = Sequential()
    model.add(LSTM(50, activation='relu', return_sequences=True))
    model.add(LSTM(50, activation='relu'))
    model.add(Dense(1))

    model.compile(loss='mse', optimizer=optimizer)

    return model

clf = KerasRegressor(build_fn=create_model, epochs=500,
                     callbacks=[tf.keras.callbacks.EarlyStopping(patience=10)])

param_grid = {
    'clf__optimizer': ['adam', 'rmsprop'],
    'clf__batch_size': [500, 45, 77]
}

pipeline = Pipeline([
    ('clf', clf)
])

tscv = TimeSeriesSplit(n_splits=5)

grid = GridSearchCV(pipeline, cv=tscv, param_grid=param_grid, return_train_score=True,
                    verbose=10, scoring='neg_mean_squared_error')

grid.fit(Xtrain2, ytrain.values)

grid.cv_results_
```

I tried passing the callbacks to `grid.fit` and also putting them in `param_grid`, but both gave me an error.

Ben Reiniger

2 Answers


You need to train the Keras model directly with the `model.fit()` function, which accepts a `callbacks` parameter.
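Whichever way it is passed in, `EarlyStopping(patience=10)` stops training once the monitored quantity (by default `val_loss`, the validation loss reported after each epoch) has gone `patience` consecutive epochs without improving. The following is a simplified pure-Python sketch of that rule, not Keras's actual implementation: `stopping_epoch` is a hypothetical helper, it assumes lower is better (as for a loss), and it ignores `min_delta`, `baseline`, and `restore_best_weights`.

```python
def stopping_epoch(val_losses, patience=10):
    """Hypothetical helper: return the 0-based epoch at which a patience rule
    would stop training, or None if it never triggers.

    Simplified: lower is better, any strict decrease counts as improvement.
    """
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Best loss at epoch 1; epochs 2, 3, 4 do not improve on it, so patience=3 stops at epoch 4.
print(stopping_epoch([1.0, 0.8, 0.9, 0.9, 0.85], patience=3))  # 4
```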

York Yang
I don't have a problem using early stopping in `model.fit()`, but I want to use cross-validation to avoid overfitting, so I have to use `GridSearchCV`. Is there any way to use both `GridSearchCV` and early stopping? – Shadi Mohebbi Jun 22 '20 at 15:32

Callbacks are specified in `KerasRegressor.fit` (docs), and `GridSearchCV.fit` accepts `**fit_params` keyword arguments. From the docs:

> `**fit_params` : dict of str -> object
>
> Parameters passed to the `fit` method of the estimator

So something along the lines of

```python
grid.fit(Xtrain2, ytrain.values, callbacks=[...])
```

should generally work. In your case, because the model is embedded inside a pipeline, you also need to prefix the parameter with the name of the pipeline step, as in

```python
grid.fit(Xtrain2, ytrain.values, clf__callbacks=[...])
```
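To check the routing without TensorFlow, here is a minimal sketch that swaps `KerasRegressor` for a hypothetical `RecordingRegressor` (not part of any library) that predicts the training mean and records what its `fit` receives. The string in `callbacks` is a placeholder for `[tf.keras.callbacks.EarlyStopping(...)]`, and the grid and `TimeSeriesSplit` mirror the question's setup on toy data.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV, TimeSeriesSplit
from sklearn.pipeline import Pipeline


class RecordingRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical stand-in for KerasRegressor: predicts the training mean
    and records the keyword arguments its fit() received."""

    def __init__(self, optimizer="rmsprop", batch_size=32):
        self.optimizer = optimizer
        self.batch_size = batch_size

    def fit(self, X, y, callbacks=None, **fit_kwargs):
        self.received_callbacks_ = callbacks
        self.received_fit_kwargs_ = fit_kwargs
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


pipeline = Pipeline([("clf", RecordingRegressor())])
param_grid = {"clf__optimizer": ["adam", "rmsprop"], "clf__batch_size": [500, 45, 77]}
grid = GridSearchCV(pipeline, param_grid=param_grid, cv=TimeSeriesSplit(n_splits=5),
                    scoring="neg_mean_squared_error")

X = np.arange(60, dtype=float).reshape(-1, 1)
y = np.sin(X).ravel()
callbacks = ["early-stopping-placeholder"]  # placeholder for [EarlyStopping(...)]

# The "clf__" prefix makes Pipeline.fit forward each argument to the "clf" step.
grid.fit(X, y, clf__callbacks=callbacks, clf__validation_split=0.1)

fitted = grid.best_estimator_.named_steps["clf"]
print(fitted.received_callbacks_)   # ['early-stopping-placeholder']
print(fitted.received_fit_kwargs_)  # {'validation_split': 0.1}
```

The extra `clf__validation_split=0.1` matters for the real model: `EarlyStopping` monitors `val_loss` by default, and Keras only reports `val_loss` when `fit` gets validation data; otherwise it logs a warning and never stops early. With the legacy `KerasRegressor`, extra keyword arguments to its `fit` should be forwarded to `model.fit` in the same way.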

See also "Can I send callbacks to a KerasClassifier?", though that question has a lot of other issues.

Ben Reiniger