
I'm training a Keras model that sits in a scikit-learn pipeline with some preprocessing steps. The Keras model is defined as:

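from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Dropout
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import EarlyStopping
# tensorflow.keras.wrappers.scikit_learn is deprecated and has been removed in
# recent TensorFlow releases; SciKeras provides a maintained KerasRegressor.
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from sklearn.pipeline import make_pipeline


def create_model(input_dim):
    # input_dim: number of features the model receives (after preprocessing)
    inp = Input(shape=(input_dim,))
    x = Dense(150, activation="relu")(inp)
    x = Dropout(0.4)(x)
    mean = Dense(1, activation="linear")(x)
    train_model_1 = Model(inp, mean)
    adam = optimizers.Adam(learning_rate=0.01)  # learning_rate replaces the deprecated lr
    train_model_1.compile(loss=my_loss_function, optimizer=adam)  # custom loss defined elsewhere
    return train_model_1


# KerasRegressor forwards keyword arguments whose names match create_model's
# parameters (here input_dim) when it builds the model. This assumes the
# preprocessing does not change the number of features.
clf = KerasRegressor(build_fn=create_model, input_dim=X_train.shape[1],
                     epochs=250, batch_size=64)

This is then used in a pipeline:

pipeline = make_pipeline(
    other_steps,  # the preprocessing step(s)
    clf,          # the KerasRegressor instance as the final step
)

pipeline.fit(X_train, y_train)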

I want to use EarlyStopping with the test data (X_test, y_test) as the validation set. Outside a pipeline this would be straightforward:

callbacks=[EarlyStopping(monitor='val_loss', patience=5)]

train_model_1.fit(X_train, y_train,
                  validation_data=(X_test, y_test),
                  callbacks=callbacks,
                  )

But I can't figure out where this would go in the pipeline. What is the right way to structure this?

John F

1 Answer


Pipeline.fit has a keyword-argument parameter:

**fit_params : dict of string -> object

Parameters passed to the fit method of each step, where each parameter name is prefixed such that parameter p for step s has key s__p.
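To make the quoted convention concrete, here is a simplified pure-Python sketch of how s__p keys can be grouped by step name. split_fit_params is a hypothetical helper written for illustration; it is not scikit-learn's own code, though it splits on the first double underscore in the same way.

```python
def split_fit_params(step_names, **fit_params):
    """Group 'step__param' keyword arguments by step name.

    Hypothetical illustration of the documented convention, not
    scikit-learn's implementation.
    """
    per_step = {name: {} for name in step_names}
    for key, value in fit_params.items():
        # Split on the first "__" only: "s__p" -> ("s", "__", "p")
        step, sep, param = key.partition("__")
        if not sep or step not in per_step:
            raise ValueError(f"{key!r} does not match 'step__param' for a known step")
        per_step[step][param] = value
    return per_step


routed = split_fit_params(
    ["standardscaler", "kerasregressor"],
    kerasregressor__epochs=250,
    kerasregressor__batch_size=64,
)
print(routed)
# {'standardscaler': {}, 'kerasregressor': {'epochs': 250, 'batch_size': 64}}
```

Each step's fit then receives only its own dictionary, so the preprocessing step gets no extra arguments here.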

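So something like pipeline.fit(X_train, y_train, kerasregressor__callbacks=callbacks, kerasregressor__validation_data=(X_test, y_test)) should work. (Check the names of your pipeline steps, e.g. with pipeline.steps; make_pipeline names each step after its lowercased class name, which for the Keras wrapper gives kerasregressor.) Note that validation_data is handed straight to the Keras model, so the pipeline's preprocessing steps are not applied to X_test; transform it with those steps (fitted on X_train) first.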
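A sketch that checks this routing without needing TensorFlow. RecordingRegressor is a hypothetical stand-in for KerasRegressor that only records the keyword arguments its fit() receives; StandardScaler stands in for the question's other_steps, and the string in callbacks is a placeholder for a real EarlyStopping callback. Because the pipeline passes validation_data to the final step untouched, the sketch preprocesses X_test with an unfitted copy of every step except the model.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


class RecordingRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical stand-in for KerasRegressor: records the extra keyword
    arguments its fit() receives and simply predicts the training mean."""

    def fit(self, X, y, **fit_kwargs):
        self.fit_kwargs_ = fit_kwargs
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


rng = np.random.default_rng(0)
X_train = rng.normal(loc=10.0, scale=3.0, size=(100, 4))
y_train = rng.normal(size=100)
X_test = rng.normal(loc=10.0, scale=3.0, size=(20, 4))
y_test = rng.normal(size=20)

pipeline = make_pipeline(StandardScaler(), RecordingRegressor())
step_names = [name for name, _ in pipeline.steps]
print(step_names)  # ['standardscaler', 'recordingregressor']

# clone(pipeline[:-1]) is an unfitted copy of all steps except the model;
# fit it on the training data and use it to preprocess the validation set.
preprocessing = clone(pipeline[:-1]).fit(X_train, y_train)
X_test_t = preprocessing.transform(X_test)

callbacks = ["placeholder for EarlyStopping(monitor='val_loss', patience=5)"]
pipeline.fit(
    X_train,
    y_train,
    recordingregressor__callbacks=callbacks,
    recordingregressor__validation_data=(X_test_t, y_test),
)

received = pipeline.named_steps["recordingregressor"].fit_kwargs_
print(sorted(received))  # ['callbacks', 'validation_data']
```

With the real model the keys become kerasregressor__callbacks and kerasregressor__validation_data. Since early stopping then picks the stopping epoch using X_test, a score computed afterwards on that same set will be somewhat optimistic; holding out a separate validation set avoids this.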

See also "How to pass a parameter to only one part of a pipeline object in scikit learn?"

Ben Reiniger