
See the code snippet:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

The output is tensor([0.]), but

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

The output is None.

I'm confused: why is the output of torch.where tensor([0.])?

Update:

import torch
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([-1., -1.], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

The output is tensor([2., 0.]). (a[0, 0] * a[0, 1]) is not related to b[1] in any way, yet the gradient of b[1] is 0, not None.
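
The same behaviour appears without any in-place assignment, which suggests indexing alone is enough to connect the whole of b to the output. A minimal sketch:

import torch
b = torch.tensor([-1., -1.], requires_grad=True)
y = b[0] * 2.  # uses only b[0]
y.backward()
print(b.grad)  # tensor([2., 0.]): b[1] gets 0, not None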

gaussclb

1 Answer


Tracking-based AD, like PyTorch's autograd, works by recording the operations you perform. It can't track through things that are not function calls intercepted by the library. With an if statement like this, there's no connection between x and y, whereas with torch.where, x and y are linked in the expression tree.
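
You can see the difference by inspecting grad_fn. A minimal sketch (the exact name of the backward node varies across PyTorch versions):

import torch

x = torch.tensor([-1.], requires_grad=True)

# torch.where is a call intercepted by autograd, so the result records
# a backward node that references x.
y_where = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
print(y_where.grad_fn)  # e.g. <WhereBackward0 ...>, linking y_where to x

# The if/else runs in plain Python, so autograd never sees it; in the
# else branch y is just a fresh, unconnected leaf tensor.
y_if = x if x > 0. else torch.tensor([2.], requires_grad=True)
print(y_if.grad_fn)  # None: y_if has no recorded history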

Now, for the differences:

  • In the first snippet, 0 is the correct derivative of the function x ↦ x > 0 ? x : 2 at the point -1, since the negative branch is constant (checked numerically in the sketch after this list).
  • In the second snippet, as I said, x is not related to y in any way (in the else branch). Therefore, the derivative of y with respect to x is undefined, which is represented as None.
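
As a quick numerical check of the first point, here is a minimal sketch evaluating the same where expression on both sides of 0:

import torch

# d/dx of (x > 0 ? x : 2) is 1 on the positive side and 0 on the
# negative side; torch.where reproduces exactly that.
for val in (-1., 1.):
    x = torch.tensor([val], requires_grad=True)
    y = torch.where(x > 0., x, torch.tensor([2.]))
    y.backward()
    print(val, x.grad)  # -1.0 tensor([0.]) and 1.0 tensor([1.])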

(You can differentiate through such native control flow even in Python, but that requires more sophisticated technology like source transformation. I don't think it is possible with PyTorch.)

phipsgabler
  • I guess the `0` gradient is equivalent to `None`; see my update. – gaussclb Apr 13 '20 at 17:41
  • No, it's still the same principle. In your new example, (a[0, 0] * a[0, 1]) is to be thought of as a function of the whole of b. You backpropagate through the whole of b; the part with respect to which the output is constant (here b[1]) just gets a zero gradient. I suggest you familiarize yourself a bit with the implementation details of AD systems in general; then this will become easier to see. – phipsgabler Apr 13 '20 at 18:27
  • Is there a tutorial on the implementation details of AD systems, especially for in-place operations? There exist many weird phenomena. – gaussclb Apr 13 '20 at 18:44
  • I have a collection of references [here](https://github.com/TuringLang/IRTracker.jl/blob/master/background.md), but it's quite Julia-focused. And mutability is rarely treated, because it makes things difficult. – phipsgabler Apr 14 '20 at 06:24