
I know about the Keras learning rate scheduler callback and tf.keras.optimizers.schedules.InverseTimeDecay, but they take only the current epoch or only the current step as an argument. What I would like is for my learning rate to stay at its initial value up to, say, the tenth epoch, and only then start applying an inverse time decay schedule. Is there a way to get both the epoch and the current step as arguments? Based on the TensorFlow source code, I tried the following as a simple way to increment a step counter myself and pass it in as a callback, but it doesn't seem to change the learning rate:

import tensorflow as tf

class CustomLearningRateScheduler(tf.keras.callbacks.Callback):

    def __init__(self, initialLearnRate, decay):
        super(CustomLearningRateScheduler, self).__init__()
        self.initialLearnRate = initialLearnRate
        self.totalBatches = 0
        self.decay = decay

    def on_train_batch_end(self, batch, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')

        # Compute the scheduled learning rate with inverse time decay on the
        # running batch counter, then increment the counter.
        scheduled_lr = self.initialLearnRate / (self.decay * self.totalBatches + 1)
        self.totalBatches += 1
        # Write the new value back to the optimizer before the next batch.
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)

lrSchedule = CustomLearningRateScheduler(initialLearnRate, decay_rate)
# start of training for 30 epochs
modified_vgg.fit(x=train_batches_seq,
    steps_per_epoch=len(train_batches_seq),
    validation_data=valid_batches_seq,
    validation_steps=len(valid_batches_seq),
    epochs=epochNumber,
    callbacks=[lrSchedule],
    verbose=1
)
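
Roughly, what I have in mind is something like the sketch below, which tracks the epoch in on_epoch_begin and the batch in on_train_batch_begin so that decay only starts after a given epoch (the names DelayedInverseTimeDecay, startEpoch and batchesSinceStart are just placeholders I made up, not part of the Keras API), but I am not sure this is the idiomatic way to do it:

import tensorflow as tf

class DelayedInverseTimeDecay(tf.keras.callbacks.Callback):
    """Keep the initial learning rate until startEpoch, then apply inverse
    time decay on the number of batches seen since that epoch."""

    def __init__(self, initialLearnRate, decay, startEpoch=10):
        super().__init__()
        self.initialLearnRate = initialLearnRate
        self.decay = decay
        self.startEpoch = startEpoch        # epoch at which decay should begin
        self.currentEpoch = 0
        self.batchesSinceStart = 0          # batches counted only after startEpoch

    def on_epoch_begin(self, epoch, logs=None):
        # Keras passes the current epoch here, so the callback sees both counters.
        self.currentEpoch = epoch

    def on_train_batch_begin(self, batch, logs=None):
        if self.currentEpoch < self.startEpoch:
            scheduled_lr = self.initialLearnRate
        else:
            scheduled_lr = self.initialLearnRate / (self.decay * self.batchesSinceStart + 1)
            self.batchesSinceStart += 1
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)

If that general approach is reasonable, the same arithmetic could presumably also be written as a single function of the global step and steps_per_epoch, but I could not find how to get the epoch into the built-in InverseTimeDecay.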
  • `but it doesn't seem to change the learning rate:` Could you maybe add a plot of the learning rate evolution? Just to confirm that the lr is not changing. – Lescurel Jul 02 '21 at 13:26
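
To check whether the rate changes at all, a small logging callback like the sketch below (assuming the standard optimizer.lr attribute) could record the learning rate after every batch so its evolution can be plotted, e.g. with plt.plot(lrLogger.lrHistory) after training:

import tensorflow as tf

class LearningRateLogger(tf.keras.callbacks.Callback):
    """Record the optimizer's learning rate after each batch for plotting."""

    def __init__(self):
        super().__init__()
        self.lrHistory = []

    def on_train_batch_end(self, batch, logs=None):
        lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
        self.lrHistory.append(lr)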

0 Answers