See the code snippet:
import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)
The output is tensor([0.]), but
import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)
The output is None.
I'm confused: why is the output of torch.where tensor([0.]) rather than None?
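In case it helps, I also kept a handle to the second argument of torch.where (bound to z below, a name I introduced for this check) and found that it does receive a gradient, so both branches seem to take part in backward:

```python
import torch

x = torch.tensor([-1.], requires_grad=True)
z = torch.tensor([2.], requires_grad=True)

# Both x and z are inputs to where, regardless of the condition.
y = torch.where(x > 0., x, z)
y.backward()

print(x.grad)  # tensor([0.]) -- x was not selected
print(z.grad)  # tensor([1.]) -- z was selected
```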
Update:
import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b
(a[0, 0] * a[0, 1]).backward()
print(b.grad)
The output is tensor([2., 0.]). (a[0, 0] * a[0, 1]) is not in any way related to b[1], but the gradient of b[1] is 0, not None.
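To check my understanding that the in-place copy is what pulls b into the graph, I printed a.requires_grad after the assignment (this is just my own probe, not part of the original question):

```python
import torch

a = torch.tensor([[1, 2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)

# In-place copy of a grad-requiring tensor into a:
# a now requires grad, so backward reaches b as a whole.
a[:, 0] = b
print(a.requires_grad)  # True

(a[0, 0] * a[0, 1]).backward()
print(b.grad)  # tensor([2., 0.]) -- b[1] gets 0, not None
```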