1

I want to stop the model using a custom callback if the val_accuracy is reducing after a certain number of steps (steps here mean training_examples/batch_size).

Here's my first attempt which works but doesn't actually stop the model training:

class CustomEarlyStopping(tf.keras.callbacks.Callback):
    def __init__(self, monitor, max_steps, mode='min', delta=0):
        super().__init__()
        self.monitor = monitor
        self.max_steps = max_steps
        self.mode = mode
        self.delta = delta
        self.wait = 0
        self.stopped_step = 0
        self.best = None

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.stopped_step = 0
        self.best = None

    def on_train_batch_end(self, batch, logs=None):
        current = logs.get(self.monitor)
        if current is None:
            return

        if self.best is None:
            self.best = current
        elif (self.mode == 'min' and current < self.best - self.delta) or (self.mode == 'max' and current > self.best + self.delta):
            self.wait += 1
            if self.wait >= self.max_steps:
                self.stopped_step = self.steps_per_epoch
                self.model.stop_training = True
        else:
            self.wait = 0
            self.best = current

early_stopping = CustomEarlyStopping(monitor='val_accuracy', max_steps=100)
mank
  • 884
  • 1
  • 6
  • 17

0 Answers0