I'm trying to understand backpropagation in PyTorch a bit better. I have a code snippet that successfully backpropagates from the output d to the leaf variable a, but if I add a reshape step in between, backpropagation no longer populates a gradient for the input a.

I know reshape is an out-of-place operation, but I'm still not sure how to make sense of this.

Any thoughts?

Thanks.

#Works
import torch

a = torch.tensor([1.])
a.requires_grad = True
b = torch.tensor([1.])
c = torch.cat([a, b])
d = torch.sum(c)
d.backward()

print('a gradient is')
print(a.grad)  # => tensor([1.])

#Doesn't work
a = torch.tensor([1.])
a.requires_grad = True
a = a.reshape(a.shape)
b = torch.tensor([1.])
c = torch.cat([a, b])
d = torch.sum(c)
d.backward()

print('a gradient is')
print(a.grad)  # => None
user49404

1 Answer

Edit:

Here is a detailed explanation of what's going on ("this isn't a bug per se, but it is definitely a source of confusion"): https://github.com/pytorch/pytorch/issues/19778

So one solution is to explicitly ask PyTorch to retain the gradient for the now non-leaf a:

a = torch.tensor([1.])
a.requires_grad = True
a = a.reshape(a.shape)
a.retain_grad()  # ask autograd to keep the gradient on this non-leaf tensor
b = torch.tensor([1.])
c = torch.cat([a, b])
d = torch.sum(c)
d.backward()
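
With retain_grad() in place, the backward pass now populates the gradient on the non-leaf a, just like in the original working example:

print('a gradient is')
print(a.grad)  # => tensor([1.])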

Old answer:

If you move a.requires_grad = True to after the reshape, it works:

a = torch.tensor([1.])
a = a.reshape(a.shape)
a.requires_grad = True  # set after the reshape, so this a is still a leaf
b = torch.tensor([1.])
c = torch.cat([a, b])
d = torch.sum(c)
d.backward()
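
Checking the gradient afterwards gives the same result as in the working example, because this a is still a leaf:

print('a gradient is')
print(a.grad)  # => tensor([1.])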

This seems like a bug in PyTorch at first, because after the following snippet a.requires_grad is still True.

a = torch.tensor([1.])
a.requires_grad = True
a = a.reshape(a.shape)

This seems to be related to the fact that a is no longer a leaf in your "Doesn't work" example, but is still a leaf in the other cases (print a.is_leaf to check).
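
A minimal sketch of that leaf/non-leaf distinction (the names leaf and view here are just for illustration): the gradient is still computed, but it accumulates on the original leaf tensor, which the reassignment a = a.reshape(a.shape) hides behind the name a:

import torch

leaf = torch.tensor([1.], requires_grad=True)  # original leaf tensor
view = leaf.reshape(leaf.shape)                # produced by an autograd op -> non-leaf

print(leaf.is_leaf)  # => True
print(view.is_leaf)  # => False

d = torch.sum(torch.cat([view, torch.tensor([1.])]))
d.backward()

print(leaf.grad)  # => tensor([1.]), the gradient lands on the original leaf
print(view.grad)  # => None (recent PyTorch versions also warn when reading .grad of a non-leaf)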

Sergii Dymchenko