
In my network, I want to compute both the forward pass and the backward pass during the forward pass. For this, I have to manually define a backward method for every layer used in the forward pass.
For the activation functions that's easy, and it also worked well for the linear and conv layers. But I'm really struggling with BatchNorm, since the BatchNorm paper only discusses the 1D case. So far, my implementation looks like this:

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    beta = layer.bias
    avg = layer.running_mean
    var = layer.running_var
    eps = layer.eps
    B = input.shape[0]

    # avg, var, gamma and beta are of shape [channel_size]
    # while input, output, grad_output are of shape [batch_size, channel_size, w, h]
    # for my calculations I reshape avg, var, gamma and beta to [batch_size, channel_size, w, h]
    # by repeating each channel value over the whole image and over the batch

    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B # dL_dxi_hat / sqrt()
    dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta

When I check my gradients with torch.autograd.grad(), dL_dgamma and dL_dbeta are correct, but dL_dxi is off by a lot, and I can't find my mistake. Where is it?

For reference, here is the definition of BatchNorm:

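Transcribed from the paper (Ioffe & Szegedy, 2015): for a mini-batch of size m, the BatchNorm transform is

\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i
\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
y_i = \gamma \hat{x}_i + \beta = \mathrm{BN}_{\gamma,\beta}(x_i)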

And here are the formulas for the derivatives in the 1D case:
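Again transcribed from the paper, with \ell the loss and m the batch size:

\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \gamma
\frac{\partial \ell}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_B) \cdot \left(-\tfrac{1}{2}\right) (\sigma_B^2 + \epsilon)^{-3/2}
\frac{\partial \ell}{\partial \mu_B} = \left( \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} \right) + \frac{\partial \ell}{\partial \sigma_B^2} \cdot \frac{\sum_{i=1}^{m} -2 (x_i - \mu_B)}{m}
\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_B^2} \cdot \frac{2 (x_i - \mu_B)}{m} + \frac{\partial \ell}{\partial \mu_B} \cdot \frac{1}{m}
\frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i} \cdot \hat{x}_i
\frac{\partial \ell}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}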

  • You can follow [this](https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html). – swag2198 Jun 14 '21 at 15:02
  • Yes, I know this article, but as the author says himself: this is the naive implementation of the backward pass. – sebko_iic Jun 15 '21 at 10:11
  • Why are you **explicitly** trying to compute the gradients? Why not use pytorch's [autograd](https://pytorch.org/docs/stable/autograd.html?highlight=autograd#module-torch.autograd) mechanism? For instance, [here](https://stackoverflow.com/a/67471097/1714410) `autograd` is used to define a loss on the _gradients_ of the forward pass. – Shai Jun 22 '21 at 11:57

1 Answer

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    gamma = gamma.view(1,-1,1,1) # edit
    # beta = layer.bias
    # avg = layer.running_mean
    # var = layer.running_var
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3] # edit

    # add new
    mean = input.mean(dim = (0,2,3), keepdim = True)
    variance = input.var(dim = (0,2,3), unbiased=False, keepdim = True)
    x_hat = (input - mean)/(torch.sqrt(variance + eps))
    
    dL_dxi_hat = grad_output * gamma
    # dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True) 
    # dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True)  * ((variance + eps) ** -1.5) # edit
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + (dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B) #edit
    
    dL_dxi = (dL_dxi_hat / torch.sqrt(variance + eps)) + (2.0 * dL_dvar * (input - mean) / B) + (dL_davg / B) # dL_dxi_hat / sqrt()
    # dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True) 
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True) # edit
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
  1. Since you didn't post your forward-pass code: if your gamma is 1-dimensional (shape [channel_size]), you need to reshape it to [1, gamma.shape[0], 1, 1] so it broadcasts against [batch_size, channel_size, w, h].
  2. Your formula follows the 1D case, where the normalization factor is the batch size. In 2D the statistics are computed over three dimensions (batch, height and width), so B = input.shape[0] * input.shape[2] * input.shape[3].
  3. running_mean and running_var are only used in test/inference mode; they are not used during training (you can find this in the paper). The mean and variance you need are computed from the current input. You can store mean, variance and x_hat = (x - mean)/sqrt(variance + eps) in your layer object during the forward pass, or recompute them as I did in the code above (# add new), and then use them in the formulas for dL_dvar, dL_davg and dL_dxi.
  4. Your dL_dgamma is incorrect because you multiplied grad_output by the layer output instead of by x_hat; it should be grad_output * x_hat. (A quick check of the whole function against autograd is sketched below.)
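Such a check can be done by comparing against torch.autograd.grad on a standard nn.BatchNorm2d in training mode. Here is a minimal sketch; the input shape, random seed and squared-sum loss are arbitrary choices for illustration, not taken from the question:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5, requires_grad=True)  # arbitrary [batch, channel, h, w]
layer = nn.BatchNorm2d(3)                        # training mode, uses batch statistics
output = layer(x)
loss = (output ** 2).sum()                       # arbitrary scalar loss

# reference gradients from autograd
dx_ref, dgamma_ref, dbeta_ref = torch.autograd.grad(loss, (x, layer.weight, layer.bias))

# manual gradients; for this loss, dL/dy is simply 2 * output
with torch.no_grad():
    grad_output = 2 * output
    dx, dgamma, dbeta = backward_batchnorm2d(x, output, grad_output, layer)

# all three comparisons should print True if the backward pass is correct
print(torch.allclose(dx, dx_ref, atol=1e-5))
print(torch.allclose(dgamma.flatten(), dgamma_ref, atol=1e-5))
print(torch.allclose(dbeta.flatten(), dbeta_ref, atol=1e-5))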
  • My forward code just uses the standard nn.BatchNorm2d(). I will test your modifications. Thanks for pointing out number 3, I was unsure about that. About number 4: it's weird, I know, but although ```output``` comes directly from BatchNorm2d, it seems to represent ```x_hat```. I'm quite sure of this because the condition stated in the paper, \sum \hat{x}_i = 0, holds for ```output``` but not for ```x_hat = (input - mean)/(torch.sqrt(variance + eps))```. However, I used running_mean and running_var, which I now know is wrong; still, the sum wasn't even close to 0. – sebko_iic Jun 22 '21 at 10:10
  • The `output` variable should represent `y = gamma * x_hat + beta`; then you will see that number 4 is correct. If your output is different, I think you may have some mistake in your forward code. – CuCaRot Jun 22 '21 at 11:05
  • I agree, but the output is coming directly from the torch implemented layer – sebko_iic Jun 22 '21 at 12:08
  • Yes, so you don't need it to compute anything. I checked my answer against pytorch `nn.BatchNorm2d` and it seems to be correct – CuCaRot Jun 23 '21 at 00:26