Keras "does" calculate the mean and std separately on each batch during training.
These values are not trained as the other two. They are just "accumulated". (The accumulated values are used only in prediction/inference)
That's why half of the parameters of a BatchNormalization layer appear as "non-trainable params" in the model summary. (So, be careful when assuming that PyTorch doesn't have these. I don't know the PyTorch layer deeply, so ignore this parenthesis if you do: since these statistics are not "trainable" by backpropagation, PyTorch may well keep them but simply not treat them as "weights". Keras lists them as "weights" because they are saved and loaded together with the model, but they are certainly not "trainable" in the usual sense.)
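For example, here is a quick sketch of what the summary shows (assuming Keras 2.x; the 10-feature input is just an illustration). The layer holds 4 vectors of 10 values each, and half of the resulting 40 parameters are reported as non-trainable:

from keras.layers import Input, BatchNormalization
from keras.models import Model

inp = Input(shape=(10,))
out = BatchNormalization()(inp)
model = Model(inp, out)
model.summary()
# gamma + beta                    -> Trainable params: 20
# moving_mean + moving_variance   -> Non-trainable params: 20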
Why is it important to accumulate these statistics? Because when you use your model for prediction (not for training), if your input is not a big batch, or if its mean and variance are very different from what the model was trained with, you may get wrong results. And I believe it's not uncommon, in real situations, to want to predict a single item instead of a big batch.
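As a concrete illustration of what goes wrong, here is a tiny NumPy sketch (the "moving" statistics below are made-up numbers): normalizing a single sample with its own batch statistics collapses it to zeros, while normalizing with statistics accumulated during training keeps the information in the sample:

import numpy as np

# A single sample at prediction time.
x = np.array([[2.0, -1.0, 5.0]])
eps = 1e-3

# Using the batch's own statistics (batch of size 1): the mean is the sample
# itself and the variance is 0, so the output collapses to zeros.
batch_mean = x.mean(axis=0)
batch_var = x.var(axis=0)
print((x - batch_mean) / np.sqrt(batch_var + eps))    # [[0. 0. 0.]]

# Using statistics accumulated during training (made-up values here),
# the normalized sample still carries information.
moving_mean = np.array([0.5, 0.0, 3.0])
moving_var = np.array([1.0, 2.0, 4.0])
print((x - moving_mean) / np.sqrt(moving_var + eps))  # roughly [[ 1.5  -0.71  1.0 ]]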
But if you want to remove them anyway, you can simply copy the code from the source and strip everything related to the moving mean and variance:
Source code for BatchNormalization
# Imports mirroring keras/layers/normalization.py (assuming Keras 2.x)
from keras import backend as K
from keras import initializers, regularizers, constraints
from keras.engine import Layer, InputSpec
from keras.legacy import interfaces


class BatchNormalization2(Layer):
    """BatchNormalization with everything regarding the moving statistics
    commented out.
    """

    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 # moving_mean_initializer='zeros',
                 # moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(BatchNormalization2, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        # self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        # self.moving_variance_initializer = (
        #     initializers.get(moving_variance_initializer))
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        # self.moving_mean = self.add_weight(
        #     shape=shape,
        #     name='moving_mean',
        #     initializer=self.moving_mean_initializer,
        #     trainable=False)
        # self.moving_variance = self.add_weight(
        #     shape=shape,
        #     name='moving_variance',
        #     initializer=self.moving_variance_initializer,
        #     trainable=False)
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        # def normalize_inference():
        #     if needs_broadcasting:
        #         # In this case we must explicitly broadcast all parameters.
        #         broadcast_moving_mean = K.reshape(self.moving_mean,
        #                                           broadcast_shape)
        #         broadcast_moving_variance = K.reshape(self.moving_variance,
        #                                               broadcast_shape)
        #         if self.center:
        #             broadcast_beta = K.reshape(self.beta, broadcast_shape)
        #         else:
        #             broadcast_beta = None
        #         if self.scale:
        #             broadcast_gamma = K.reshape(self.gamma,
        #                                         broadcast_shape)
        #         else:
        #             broadcast_gamma = None
        #         return K.batch_normalization(
        #             inputs,
        #             broadcast_moving_mean,
        #             broadcast_moving_variance,
        #             broadcast_beta,
        #             broadcast_gamma,
        #             axis=self.axis,
        #             epsilon=self.epsilon)
        #     else:
        #         return K.batch_normalization(
        #             inputs,
        #             self.moving_mean,
        #             self.moving_variance,
        #             self.beta,
        #             self.gamma,
        #             axis=self.axis,
        #             epsilon=self.epsilon)

        # # If the learning phase is *static* and set to inference:
        # if training in {0, False}:
        #     return normalize_inference()

        # (Original comment kept below; the inference branch above was removed,
        #  so this path now runs unconditionally.)
        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs)[axis]
                                  for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        # self.add_update([K.moving_average_update(self.moving_mean,
        #                                          mean,
        #                                          self.momentum),
        #                  K.moving_average_update(self.moving_variance,
        #                                          variance,
        #                                          self.momentum)],
        #                 inputs)

        # # Pick the normalized form corresponding to the training phase.
        # return K.in_train_phase(normed_training,
        #                         normalize_inference,
        #                         training=training)
        return normed_training

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            # 'moving_mean_initializer':
            #     initializers.serialize(self.moving_mean_initializer),
            # 'moving_variance_initializer':
            #     initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(BatchNormalization2, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
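A quick usage sketch (the input shape and file name are just examples): the layer drops in exactly like the regular one, and now all of its parameters are trainable because there are no moving statistics left. If you save a model containing it, load it back with custom_objects so Keras can find the class:

from keras.layers import Input
from keras.models import Model, load_model

inp = Input(shape=(10,))
out = BatchNormalization2()(inp)
model = Model(inp, out)
model.summary()   # Total params: 20 -- all trainable (gamma + beta only)

model.save('no_moving_stats.h5')
model = load_model('no_moving_stats.h5',
                   custom_objects={'BatchNormalization2': BatchNormalization2})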