
I have created two instances of the same custom model in TensorFlow 2.9.1 (i.e., model = Model() and ema_model = Model()). During training of model in a custom loop, I want to compute the exponential moving average (EMA) of its weights and update ema_model with those averaged variables.

I checked this solution and also tried ema_model.set_weights(model.get_weights()), but neither attempt was successful. To be specific, I used them right after the optimizer step in the train_step function.

In other words, I want the parameters of model to follow the normal training, while the parameters of ema_model are updated as a decayed (averaged) version of model's parameters.
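For reference, the update being described is the standard EMA recurrence, ema = decay * ema + (1 - decay) * current. A framework-free sketch of just that rule (plain Python, with hypothetical weight lists standing in for model variables):

```python
def ema_update(ema_weights, model_weights, decay=0.999):
    # Standard EMA rule applied element-wise: ema <- decay * ema + (1 - decay) * current
    return [decay * e + (1 - decay) * w
            for e, w in zip(ema_weights, model_weights)]

# Toy illustration: the EMA of a constant value converges toward that value.
ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.5)
# After 3 steps with decay=0.5, ema[0] == 0.875
```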

Any hints/solutions to this problem?

Cesar

1 Answer


I am trying out the same thing. Here's the solution I have come up with:

import tensorflow as tf

class EMA(tf.keras.callbacks.Callback):
    def __init__(self, decay=0.996):
        super(EMA, self).__init__()
        self.decay = decay

        # Create an ExponentialMovingAverage object
        self.ema = tf.train.ExponentialMovingAverage(decay=self.decay)

    def on_train_begin(self, logs=None):
        # The first call to apply() creates the shadow (EMA) variables.
        self.ema.apply(self.model.get_layer('anchor_model').trainable_variables)

    def on_train_batch_end(self, batch, logs=None):
        # Get exponential moving average of anchor model weights.
        train_vars = self.model.get_layer('anchor_model').trainable_variables
        averages = [self.ema.average(var) for var in train_vars]

        # Assign the average weights to target model
        target_model_vars = self.model.get_layer('target_model').non_trainable_variables
        assert len(target_model_vars) == len(averages)
        for i, var in enumerate(target_model_vars):
            var.assign(averages[i])

        # Update the shadow averages with the latest anchor weights.
        self.ema.apply(self.model.get_layer('anchor_model').trainable_variables)
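Since the question asks about a custom training loop rather than callbacks, the same tf.train.ExponentialMovingAverage pattern can also live inside a hand-written train_step. A minimal sketch, assuming a toy TinyModel (a stand-in for the asker's Model()) and plain SGD so nothing Keras-specific is involved:

```python
import tensorflow as tf

# Toy stand-in for the asker's custom Model(): a single linear layer.
class TinyModel(tf.Module):
    def __init__(self):
        self.w = tf.Variable(tf.random.normal([8, 4]))
        self.b = tf.Variable(tf.zeros([4]))

    def __call__(self, x):
        return x @ self.w + self.b

model = TinyModel()
ema_model = TinyModel()
# Start the EMA copy from the same point as the trained model.
for e, v in zip(ema_model.trainable_variables, model.trainable_variables):
    e.assign(v)

ema = tf.train.ExponentialMovingAverage(decay=0.996)
ema.apply(model.trainable_variables)  # first apply() creates shadow variables

def train_step(x, y, lr=0.1):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean((model(x) - y) ** 2)
    grads = tape.gradient(loss, model.trainable_variables)
    # Plain SGD update on the trained model.
    for g, v in zip(grads, model.trainable_variables):
        v.assign_sub(lr * g)
    # Update the shadow averages, then push them into ema_model.
    ema.apply(model.trainable_variables)
    for e, v in zip(ema_model.trainable_variables, model.trainable_variables):
        e.assign(ema.average(v))
    return loss
```

After each step, model holds the normally trained weights while ema_model holds their decayed average, which is exactly the behavior the question describes.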
ayush thakur