In my network, I want to compute both the forward pass and the backward pass during the forward pass.
To do this, I have to manually define the backward method for every layer used in the forward pass.
For the activation functions that's easy, and it also worked well for the linear and conv layers. But I'm really struggling with BatchNorm, since the BatchNorm paper only discusses the 1D case.
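For example, for ReLU the backward only needs the input that the forward pass saw (this is just an illustrative sketch, not the exact code from my network):

import torch

def backward_relu(input, grad_output):
    # ReLU lets the gradient through only where the forward input was positive
    return grad_output * (input > 0).to(grad_output.dtype)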
So far, my BatchNorm2d implementation looks like this:
import torch

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    beta = layer.bias
    avg = layer.running_mean
    var = layer.running_var
    eps = layer.eps
    B = input.shape[0]

    # avg, var, gamma and beta are of shape [channel_size], while input, output and
    # grad_output are of shape [batch_size, channel_size, h, w]. For my calculations
    # I reshape them to [1, channel_size, 1, 1] so that broadcasting repeats the
    # channel values over the whole image and over the batch.
    gamma = gamma.view(1, -1, 1, 1)
    beta = beta.view(1, -1, 1, 1)
    avg = avg.view(1, -1, 1, 1)
    var = var.view(1, -1, 1, 1)

    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B
    dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dbeta = grad_output.sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
When I check my gradients with torch.autograd.grad(), dL_dgamma and dL_dbeta are correct, but dL_dxi is off by a lot, and I can't find my mistake. Where is it?
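For reference, this is roughly the check I run (a minimal sketch with a freshly initialized nn.BatchNorm2d in training mode, a random upstream gradient, and arbitrary sizes):

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.BatchNorm2d(3)                        # freshly initialized, training mode
x = torch.randn(4, 3, 8, 8, requires_grad=True)

y = layer(x)
v = torch.randn_like(y)                          # random upstream gradient
loss = (y * v).sum()                             # so that dloss/dy == v

# reference gradients from autograd
dx_ref, dgamma_ref, dbeta_ref = torch.autograd.grad(loss, (x, layer.weight, layer.bias))

# my manual gradients
dx, dgamma, dbeta = backward_batchnorm2d(x, y, v, layer)

print((dbeta.flatten() - dbeta_ref).abs().max())    # ~0
print((dgamma.flatten() - dgamma_ref).abs().max())  # ~0
print((dx - dx_ref).abs().max())                    # large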
For reference, here is the definition of BatchNorm:
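Written out for a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$, as in the paper (Ioffe & Szegedy, 2015):

$$
\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2, \qquad
\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad
y_i = \gamma \hat{x}_i + \beta
$$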
And here are the formulas for the derivatives for the 1D case:
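With $\ell$ denoting the loss, these are (again from the paper):

$$
\begin{aligned}
\frac{\partial \ell}{\partial \hat{x}_i} &= \frac{\partial \ell}{\partial y_i} \cdot \gamma \\
\frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \frac{-1}{2}\left(\sigma_\mathcal{B}^2 + \epsilon\right)^{-3/2} \\
\frac{\partial \ell}{\partial \mu_\mathcal{B}} &= \left( \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} \right) + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{\sum_{i=1}^{m} -2 (x_i - \mu_\mathcal{B})}{m} \\
\frac{\partial \ell}{\partial x_i} &= \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{2 (x_i - \mu_\mathcal{B})}{m} + \frac{\partial \ell}{\partial \mu_\mathcal{B}} \cdot \frac{1}{m} \\
\frac{\partial \ell}{\partial \gamma} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i} \cdot \hat{x}_i \\
\frac{\partial \ell}{\partial \beta} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}
\end{aligned}
$$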