
I'm trying to calculate the gradient of the output of a simple neural network with respect to the inputs. The result looks fine when I don't use a BatchNorm layer. Once I do use it, the result doesn't seem to make much sense. Below is a short example to reproduce the effect.

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class Net(nn.Module):
    def __init__(self, batch_norm):
        super().__init__()
        
        self.batch_norm = batch_norm
        self.act_fn = nn.Tanh()
        
        self.aff1 = nn.Linear(1, 10)
        self.aff2 = nn.Linear(10, 1)
        
        if batch_norm:
            self.bn = nn.BatchNorm1d(10, affine=False)  # False for simplicity
        
    def forward(self, x):
        x = self.aff1(x)
        x = self.act_fn(x)
        
        if self.batch_norm:
            x = self.bn(x)
            
        x = self.aff2(x)
        return x

x_vals = torch.linspace(0, 1, 100)
x_vals.requires_grad = True

fig, axs = plt.subplots(ncols=2, figsize=(16, 5))

for seed, bn, ax1 in zip([11, 4], [False, True], axs):  # different seeds for better illustration of effect
    torch.manual_seed(seed)
    net = Net(batch_norm=bn)

    net.train()
    pred = net(x_vals[:, None])
    pred_dx = torch.autograd.grad(pred.sum(), x_vals, create_graph=True)[0]

    # visualization
    ax2 = ax1.twinx()

    ax1.plot(x_vals.detach(), pred.detach())
    ax2.plot(x_vals.detach(), pred_dx.detach(), linestyle='--', color='orange')

    min_idx = torch.argmin((pred[1:]-pred[:-1])**2)  # index where the prediction is (approximately) flattest
    ax2.axvline(x_vals[min_idx].detach(), color='gray', linestyle='dotted')
    ax2.axhline(0, color='gray', linestyle='dotted')
    ax1.set_title(('With' if bn else 'Without') + ' Batch Norm')
    
plt.show()

[Visualization: prediction (solid) and its gradient w.r.t. x (dashed), panels "Without Batch Norm" and "With Batch Norm"]

The result also seems to be fine when I use evaluation mode. Unfortunately, I can't just switch to eval() mode, because the nature of my problem (physics-informed neural networks, PINNs) requires calculating gradients during training.

I understand that during training the running mean and variance are updated. Maybe that has an impact? Can I still get the correct gradient somehow?

Thanks for your help!


1 Answer


The naming of .train() makes this a bit confusing, but:

  • net.train() puts layers like batch normalization and dropout into their training behaviour. So yes, the running mean and variance will be updated, but the gradient is still computed.
  • wrapping the forward pass in torch.no_grad(), or setting requires_grad = False on your tensors, is what prevents gradient computation (see the sketch below).
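
A minimal sketch of the difference, reusing the Net class from the question (assuming it is defined as above):

net = Net(batch_norm=True)
net.train()  # BN uses batch statistics and updates its running stats

x = torch.linspace(0, 1, 100, requires_grad=True)

# train() does NOT disable autograd: the gradient is still computed
pred = net(x[:, None])
grad_train = torch.autograd.grad(pred.sum(), x)[0]

# torch.no_grad() is what prevents gradient computation
with torch.no_grad():
    pred_no_grad = net(x[:, None])
# torch.autograd.grad(pred_no_grad.sum(), x)  # would fail: no graph was recorded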

So if you need to "train" your batch normalization, you won't really be able to get a gradient that is unaffected by the batch normalization. If you don't need to train it, just put your model in evaluation mode, since gradients are still computed there.
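
For that second case, a minimal sketch (again reusing the question's Net; it assumes the running statistics have already been populated, e.g. by earlier training steps):

net = Net(batch_norm=True)
# ... training here fills bn.running_mean and bn.running_var ...

net.eval()  # BN now uses the fixed running statistics instead of per-batch statistics

x = torch.linspace(0, 1, 100, requires_grad=True)
pred = net(x[:, None])

# gradients are still computed in eval mode
pred_dx = torch.autograd.grad(pred.sum(), x, create_graph=True)[0]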