Is there a way to make Keras calculate the mean and std separately on each batch in every BatchNormalization layer, instead of training them (similar to how PyTorch does it)?

To elaborate: the Keras BatchNormalization layer holds four sets of weights [gamma, beta, moving_mean, moving_variance] and updates them on every batch.

PyTorch's BatchNorm2d, on the other hand, holds only two sets of weights [gamma, beta], and the mean and std are calculated for every batch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
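For reference, the per-batch computation in question is the standard batch-normalization transform, shown here for a single feature over a batch of m values x_1 to x_m, with γ and β as the two learned weight vectors:

```latex
\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \mu_B\right)^2

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
```

Only γ and β are learned by backpropagation; μ_B and σ_B² are recomputed from each batch.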

I found a similar question, "How to set weights of the batch normalization layer?", but it does not provide the functionality I described (similar to PyTorch).

Daniel Möller
Ilya

1 Answer

Keras "does" calculate the mean and std separately on each batch during training.

These values are not trained like the other two; they are just "accumulated". (The accumulated values are used only during prediction/inference.)
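To make the distinction concrete, here is a minimal pure-Python sketch of a single feature of a batch-norm layer. It is a simplification, not Keras's actual implementation: gamma and beta would be learned by backpropagation (not shown), while moving_mean and moving_variance are only updated as exponential moving averages of the batch statistics during training and are read during inference. The class name ToyBatchNorm is hypothetical; momentum=0.99 and epsilon=1e-3 mirror the defaults of Keras' BatchNormalization. The sketch uses the plain (biased) batch variance everywhere, whereas Keras applies a small sample-size correction before the moving-average update.

```python
import math
from typing import List


class ToyBatchNorm:
    """Hypothetical single-feature batch norm (illustration, not Keras code)."""

    def __init__(self, momentum: float = 0.99, epsilon: float = 1e-3) -> None:
        self.gamma = 1.0            # trainable: learned by backprop (not shown)
        self.beta = 0.0             # trainable: learned by backprop (not shown)
        self.moving_mean = 0.0      # non-trainable: only accumulated
        self.moving_variance = 1.0  # non-trainable: only accumulated
        self.momentum = momentum
        self.epsilon = epsilon

    def _normalize(self, xs: List[float], mean: float, var: float) -> List[float]:
        scale = self.gamma / math.sqrt(var + self.epsilon)
        return [(x - mean) * scale + self.beta for x in xs]

    def __call__(self, xs: List[float], training: bool) -> List[float]:
        if training:
            # Training: normalize with the statistics of this batch...
            mean = sum(xs) / len(xs)
            var = sum((x - mean) ** 2 for x in xs) / len(xs)
            # ...and fold them into the exponential moving averages.
            m = self.momentum
            self.moving_mean = m * self.moving_mean + (1 - m) * mean
            self.moving_variance = m * self.moving_variance + (1 - m) * var
            return self._normalize(xs, mean, var)
        # Inference: use the accumulated statistics, not the batch's own.
        return self._normalize(xs, self.moving_mean, self.moving_variance)


bn = ToyBatchNorm(momentum=0.5)
print(bn([1.0, 2.0, 3.0, 4.0], training=True))  # centred on 0 (beta)
print(bn.moving_mean, bn.moving_variance)        # 1.25 1.125
```

The accumulated values never receive gradients; they are simply carried along so that prediction does not have to depend on the batch being predicted.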

That's why half of the parameters of a BatchNormalization layer appear as "Non-trainable params" in the model summary. (So be careful when assuming that PyTorch doesn't have these. I don't know the PyTorch layer deeply; if you do, ignore this parenthesis. Since these values are not trainable by backpropagation, PyTorch may actually have them without treating them as "weights". Keras stores them as "weights" because they are saved/loaded along with the model, but they are certainly not trainable in the usual sense.)
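A quick arithmetic check of that split, assuming the default center=True and scale=True: each BatchNormalization layer creates four vectors whose length equals the size of the normalized axis, two trainable (gamma, beta) and two non-trainable (moving_mean, moving_variance). The helper name keras_batchnorm_param_counts is hypothetical.

```python
from typing import Dict


def keras_batchnorm_param_counts(num_features: int, center: bool = True,
                                 scale: bool = True) -> Dict[str, int]:
    """Hypothetical helper: parameter split of one Keras BatchNormalization layer."""
    # gamma (if scale) and beta (if center) are learned by backpropagation.
    trainable = num_features * (int(scale) + int(center))
    # moving_mean and moving_variance are always created and only accumulated.
    non_trainable = 2 * num_features
    return {"total": trainable + non_trainable,
            "trainable": trainable,
            "non_trainable": non_trainable}


# e.g. BatchNormalization after a Conv2D with 64 filters (channels-last)
print(keras_batchnorm_param_counts(64))
# {'total': 256, 'trainable': 128, 'non_trainable': 128}
```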

Why is it important to accumulate these statistics? Because when you use your model for prediction (not for training), your input may not be a big batch, or its mean and std may be very different from what the model was trained with; normalizing with that input's own statistics can then give wrong results. I believe it's not uncommon to want to predict a single item instead of a big batch in real situations.
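A small pure-Python sketch of the single-item case, with invented numbers: a batch of one sample has variance 0 and a mean equal to the sample itself, so normalizing it with its own statistics always yields exactly beta, whatever the input. Normalizing with accumulated statistics (here hypothetical values standing in for moving_mean and moving_variance) keeps different inputs distinguishable. The function names are illustrative only.

```python
import math
from typing import List, Tuple


def batch_stats(xs: List[float]) -> Tuple[float, float]:
    """Mean and (biased) variance of one feature over the batch."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return mean, var


def normalize(xs: List[float], mean: float, var: float, gamma: float = 1.0,
              beta: float = 0.5, epsilon: float = 1e-3) -> List[float]:
    scale = gamma / math.sqrt(var + epsilon)
    return [(x - mean) * scale + beta for x in xs]


# Invented values standing in for statistics accumulated during training.
moving_mean, moving_variance = 5.0, 4.0

for sample in ([2.0], [9.0]):
    with_own_stats = normalize(sample, *batch_stats(sample))
    with_moving_stats = normalize(sample, moving_mean, moving_variance)
    print(sample, with_own_stats, with_moving_stats)
# Both samples give [0.5] (= beta) with their own statistics,
# but different outputs with the accumulated ones.
```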

But if you want to remove them anyway, you can simply copy the code from the source and remove everything regarding the moving mean and variance:

Adapted from the Keras 2.x source code for BatchNormalization:

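# Assumes standalone multi-backend Keras 2.x (e.g. 2.2.x), where these modules exist.
from keras import backend as K
from keras import constraints, initializers, regularizers
from keras.engine import InputSpec, Layer
from keras.legacy import interfaces


class BatchNormalization2(Layer):
    """
    Copy of Keras' BatchNormalization with everything regarding the moving
    statistics commented out: the layer always normalizes with the mean and
    variance of the current batch, during training and inference alike.
    """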

    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,  # unused now; kept for config compatibility
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 #moving_mean_initializer='zeros',
                 #moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(BatchNormalization2, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        #self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        #self.moving_variance_initializer = (initializers.get(moving_variance_initializer))
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        #self.moving_mean = self.add_weight(
        #    shape=shape,
        #    name='moving_mean',
        #    initializer=self.moving_mean_initializer,
        #    trainable=False)
        #self.moving_variance = self.add_weight(
        #    shape=shape,
        #    name='moving_variance',
        #    initializer=self.moving_variance_initializer,
        #    trainable=False)

        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        #def normalize_inference():
        #    if needs_broadcasting:
        #        # In this case we must explicitly broadcast all parameters.
        #        broadcast_moving_mean = K.reshape(self.moving_mean,
        #                                          broadcast_shape)
        #        broadcast_moving_variance = K.reshape(self.moving_variance,
        #                                              broadcast_shape)
        #        if self.center:
        #            broadcast_beta = K.reshape(self.beta, broadcast_shape)
        #        else:
        #            broadcast_beta = None
        #        if self.scale:
        #            broadcast_gamma = K.reshape(self.gamma,
        #                                        broadcast_shape)
        #        else:
        #            broadcast_gamma = None
        #        return K.batch_normalization(
        #            inputs,
        #            broadcast_moving_mean,
        #            broadcast_moving_variance,
        #            broadcast_beta,
        #            broadcast_gamma,
        #            axis=self.axis,
        #            epsilon=self.epsilon)
        #    else:
        #        return K.batch_normalization(
        #            inputs,
        #            self.moving_mean,
        #            self.moving_variance,
        #            self.beta,
        #            self.gamma,
        #            axis=self.axis,
        #            epsilon=self.epsilon)

        ## If the learning phase is *static* and set to inference:
        #if training in {0, False}:
        #    return normalize_inference()

        # Always normalize with the statistics of the current batch,
        # both in training and in inference.
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        # The sample-size correction below only fed the moving-variance
        # update, so it is commented out as well; it never affected the output.
        #if K.backend() != 'cntk':
        #    sample_size = K.prod([K.shape(inputs)[axis]
        #                          for axis in reduction_axes])
        #    sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
        #    if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
        #        sample_size = K.cast(sample_size, dtype='float32')
        #
        #    # sample variance - unbiased estimator of population variance
        #    variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        #self.add_update([K.moving_average_update(self.moving_mean,
        #                                         mean,
        #                                         self.momentum),
        #                 K.moving_average_update(self.moving_variance,
        #                                         variance,
        #                                         self.momentum)],
        #                inputs)

        ## Pick the normalized form corresponding to the training phase.
        #return K.in_train_phase(normed_training,
        #                        normalize_inference,
        #                        training=training)
        return normed_training

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            #'moving_mean_initializer':
            #    initializers.serialize(self.moving_mean_initializer),
            #'moving_variance_initializer':
            #    initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(BatchNormalization2, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
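
One consequence of BatchNormalization2, shown with a pure-Python stand-in for its call() (the same per-batch computation for a single feature, not the Keras code itself): because the inference branch is gone, prediction runs exactly the same formula as training, and the output for a given sample depends on which other samples share its batch. The function name per_batch_norm is hypothetical.

```python
import math
from typing import List


def per_batch_norm(xs: List[float], gamma: float = 1.0, beta: float = 0.0,
                   epsilon: float = 1e-3) -> List[float]:
    """Always normalize with the current batch's mean/variance, like the modified layer."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    scale = gamma / math.sqrt(var + epsilon)
    return [(x - mean) * scale + beta for x in xs]


# The same sample (1.0) placed in two different batches:
first = per_batch_norm([1.0, 2.0, 3.0])[0]
second = per_batch_norm([1.0, 1.0, 10.0])[0]
print(first, second)  # roughly -1.22 and -0.71: the result depends on the batch-mates
```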
Daniel Möller
You are right: in PyTorch the accumulated mean and var are called running_mean and running_var. Thanks! – Ilya, Oct 09 '19 at 14:16