I want to stop the model using a custom callback if the val_accuracy
is reducing after a certain number of steps (steps here mean training_examples/batch_size).
Here's my first attempt which works but doesn't actually stop the model training:
class CustomEarlyStopping(tf.keras.callbacks.Callback):
def __init__(self, monitor, max_steps, mode='min', delta=0):
super().__init__()
self.monitor = monitor
self.max_steps = max_steps
self.mode = mode
self.delta = delta
self.wait = 0
self.stopped_step = 0
self.best = None
def on_train_begin(self, logs=None):
self.wait = 0
self.stopped_step = 0
self.best = None
def on_train_batch_end(self, batch, logs=None):
current = logs.get(self.monitor)
if current is None:
return
if self.best is None:
self.best = current
elif (self.mode == 'min' and current < self.best - self.delta) or (self.mode == 'max' and current > self.best + self.delta):
self.wait += 1
if self.wait >= self.max_steps:
self.stopped_step = self.steps_per_epoch
self.model.stop_training = True
else:
self.wait = 0
self.best = current
early_stopping = CustomEarlyStopping(monitor='val_accuracy', max_steps=100)