
I'm working with Keras and trying to create a learning rate scheduler that schedules based on the number of batches processed, rather than the number of epochs. To do this, I've inserted the scheduling code into the `get_updates` method of my `Optimizer`. For the most part, I've tried to use regular Python variables for values that remain constant during a given training run, and computational graph nodes only for parameters that actually vary.

My two questions are:

  1. Does the code below look like it should behave properly as a learning rate scheduler if placed within the `get_updates` method of a Keras `Optimizer`?

  2. How could one embed this code in a class similar to `LearningRateScheduler`, but which schedules based on the number of batches rather than the number of epochs?


    #Copying graph node that stores original value of learning rate
    lr = self.lr 

    # Checking whether learning rate schedule is to be used
    if self.initial_lr_decay > 0:
        # this decay mimics exponential decay from 
        # tensorflow/python/keras/optimizer_v2/exponential_decay 

        # Get value of current number of processed batches from graph node
        # and convert to numeric value for use in K.pow()
        curr_batch = float(K.get_value(self.iterations))

        # Create graph node containing lr decay factor
        # Note: self.lr_decay_steps is a number, not a node
        #       self.lr_decay is a node, not a number
        decay_factor = K.pow(self.lr_decay, (curr_batch / self.lr_decay_steps))

        # Reassign lr to graph node formed by
        # product of graph node containing decay factor
        # and graph node containing original learning rate.
        lr = lr * decay_factor

        # Get product of two numbers to calculate number of batches processed
        # in warmup period
        num_warmup_batches = self.steps_per_epoch_num * self.warmup_epochs

        # Make comparisons between numbers to determine if we're in warmup period
        if (self.warmup_epochs > 0) and (curr_batch < num_warmup_batches):

            # Create node with value of learning rate by multiplying a number
            # by a node, and then dividing by a number
            lr = (self.initial_lr *
                  K.cast(self.iterations, K.floatx()) / curr_batch)
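For reference, here is the schedule above re-expressed as a plain Python function, plus hypothetical callback wiring toward question 2 (assumptions: the warmup is meant to ramp linearly to the initial rate, and the name `BatchLearningRateScheduler` is made up, not Keras API):

```python
def batch_lr(batch, initial_lr, lr_decay, lr_decay_steps, warmup_batches):
    """Learning rate for a given batch count: linear warmup (assumed
    intent), then exponential decay as in the snippet above."""
    if warmup_batches > 0 and batch < warmup_batches:
        # linear warmup from 0 up to initial_lr over the warmup period
        return initial_lr * batch / warmup_batches
    # exponential decay, mimicking tf's exponential_decay
    return initial_lr * lr_decay ** (batch / lr_decay_steps)

# Hypothetical wiring into a LearningRateScheduler-style callback
# (requires keras; guarded so the schedule above works stand-alone):
try:
    from keras.callbacks import Callback
    from keras import backend as K

    class BatchLearningRateScheduler(Callback):
        """Like LearningRateScheduler, but applies `schedule` per batch."""
        def __init__(self, schedule):
            super().__init__()
            self.schedule = schedule
            self.batches_seen = 0

        def on_batch_begin(self, batch, logs=None):
            K.set_value(self.model.optimizer.lr,
                        self.schedule(self.batches_seen))
            self.batches_seen += 1
except ImportError:
    pass
```

This is a sketch, not tested against a real model; the point is only that the schedule can live in a callback instead of inside `get_updates`.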
user1245262

1 Answer


Easier than messing with the Keras source code (it's possible, but it's complex and delicate), you could use a callback.

import keras
from keras.callbacks import LambdaCallback

total_batches = 0
def what_to_do_when_batch_ends(batch, logs):
    global total_batches  # needed to modify the module-level counter
    total_batches += 1    # or use the "batch" argument,
                          # which is the index of the last finished batch

    # change the learning rate at will
    if your_condition:
        keras.backend.set_value(model.optimizer.lr, newLrValueAsPythonFloat)

When training, use the callback:

lrUpdater = LambdaCallback(on_batch_end = what_to_do_when_batch_ends)
model.fit(........, callbacks = [lrUpdater, ...other callbacks...])
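For instance, the placeholder condition could be replaced by an actual per-batch schedule. The sketch below keeps the logic in pure Python so it runs without Keras: `set_lr` stands in for `keras.backend.set_value`, and the hyperparameter values are assumed, not taken from the question:

```python
# Assumed hyperparameters for illustration only
initial_lr = 0.1
lr_decay = 0.5
lr_decay_steps = 1000

lr_log = []
def set_lr(new_lr):
    """Stand-in for keras.backend.set_value(model.optimizer.lr, new_lr)."""
    lr_log.append(new_lr)

total_batches = 0
def on_batch_end(batch, logs=None):
    global total_batches
    total_batches += 1
    # exponential decay per batch, as in the question's schedule
    set_lr(initial_lr * lr_decay ** (total_batches / lr_decay_steps))

# simulate three batch-end events
for b in range(3):
    on_batch_end(b)
```

Swapping `set_lr` back for `keras.backend.set_value` and registering the function through `LambdaCallback(on_batch_end=...)` gives the real thing.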
Daniel Möller
  • 84,878
  • 18
  • 192
  • 214
  • Thanks, I'll read up on `LambdaCallback`. But can't I just use `model.optimizer.iterations` to keep track of my batch? – user1245262 Apr 07 '20 at 02:59
  • During training it isn't very easy: it's a symbolic tensor, you can't evaluate it directly, so you'd need conditions based on tensor functions instead of `if`s, plus an update function. It's complex. Possible, but complex. – Daniel Möller Apr 07 '20 at 03:01
  • Hi Daniel, I encountered the same issue (I want the learning rate to decay per batch, not per epoch), and the solution in your answer is clear. One more question about the `model` variable you use in `keras.backend.set_value(model.optimizer.lr, newLrValueAsPythonFloat)`: is it doable to reference the model in a plain function rather than a callback? In a callback I can use `self.model`, but in a function, referencing `model` directly gives an error. Looking forward to your reply. Thanks in advance. – Summer May 25 '21 at 02:40
  • Sorry, the error is not an unresolved `model` but `name 'model' is not defined`. – Summer May 25 '21 at 05:18
  • If `model` isn't defined in that scope, use the variable you assigned your model to. As you can see, the lambda callback doesn't take `self`. So either create your own callback and take `self.model` from that custom callback, or use the variable where you defined your model. With some luck, you can take the model from the callback with `lrUpdater.model` (never tested, might have quirks). – Daniel Möller May 27 '21 at 19:45
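To make the `self.model` point concrete, here is a pure-Python mock of how `fit()` attaches the model to a callback before training; the `Mock*` classes and `BatchLRUpdater` are stand-ins for illustration, not real Keras:

```python
class MockOptimizer:
    def __init__(self, lr):
        self.lr = lr

class MockModel:
    def __init__(self, lr):
        self.optimizer = MockOptimizer(lr)

class Callback:  # stand-in for keras.callbacks.Callback
    def set_model(self, model):
        self.model = model  # fit() calls this before training starts

class BatchLRUpdater(Callback):
    def on_batch_end(self, batch, logs=None):
        # self.model exists because fit() called set_model() for us
        self.model.optimizer.lr *= 0.9  # e.g. decay 10% per batch

model = MockModel(lr=0.1)
cb = BatchLRUpdater()
cb.set_model(model)      # what model.fit(..., callbacks=[cb]) does internally
cb.on_batch_end(batch=0)
print(model.optimizer.lr)  # the lr has been decayed once: 0.1 * 0.9
```

This is why a custom callback subclass sees `self.model` while a bare function only sees whatever `model` variable happens to exist in its enclosing scope.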