The reason it works without retain_graph=True in your case is that your graph is very simple and probably has no internal intermediate buffers. Since there are no buffers to free, there is no need for retain_graph=True.
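As a minimal sketch of that simple case: the snippet below assumes the graph is a single addition, y = x + 2 (a stand-in chosen for illustration, not necessarily your exact code). An addition like this saves no tensors for its backward pass, so a second backward call should go through without retain_graph=True.

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2  # assumed stand-in for the "very simple graph"

y.backward(torch.ones(2, 2))
first = x.grad.clone()   # gradient after the first pass

y.backward(torch.ones(2, 2))  # no retain_graph=True was used above
second = x.grad.clone()  # gradients accumulate across passes

print('after 1st backward:', first)
print('after 2nd backward:', second)
```

The second value is twice the first only because x.grad accumulates; the point is that no RuntimeError is raised.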
But things change as soon as you add one more computation to your graph:
Code:
import torch

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)  # extra intermediate step
y = v + 2

# First backward pass, graph is not retained
y.backward(torch.ones(2, 2))
print('Backward 1st time w/o retain')
print('x.grad:', x.grad)

# Second backward pass through the same graph
print('Backward 2nd time w/o retain')
try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)
Output:
Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
[3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
[3., 3.]])
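In this case autograd has to propagate a gradient through the intermediate tensor v, and to do that the pow operation keeps internal buffers (its saved input). torch doesn't store intermediate values (intermediate gradients etc.) beyond what backward needs, so with retain_graph=False those buffers are freed after the first backward, and a second backward through them fails.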
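To see that intermediate gradients are not kept by default, here is a small sketch comparing two intermediate tensors. The names w and v_grad are introduced only for this illustration. Note that retain_grad() is a different switch from retain_graph: it keeps the gradient of one non-leaf tensor, not the graph's buffers.

```python
import warnings
import torch

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)       # intermediate tensor, gradient not kept
w = x.pow(3)       # same computation, but we ask to keep its gradient
w.retain_grad()

(v + 2).backward(torch.ones(2, 2))
(w + 2).backward(torch.ones(2, 2))

with warnings.catch_warnings():
    warnings.simplefilter('ignore')  # reading .grad of a non-leaf tensor warns
    v_grad = v.grad

print('v.grad:', v_grad)  # None: nothing was stored for v
print('w.grad:', w.grad)  # all ones, since d(w + 2)/dw = 1
```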
So, if you want to backprop a second time, you need to specify retain_graph=True on the first backward call to "keep" the graph.
Code:
import torch

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)  # extra intermediate step
y = v + 2

# First backward pass, keeping the graph's buffers alive
y.backward(torch.ones(2, 2), retain_graph=True)
print('Backward 1st time w/ retain')
print('x.grad:', x.grad)

# Second backward pass now succeeds
print('Backward 2nd time w/ retain')
try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)
Output:
Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
[3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
[6., 6.]])
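Note that the second result is 6 rather than 3 because gradients accumulate in x.grad across backward calls. If you want each pass to start from a fresh gradient (an assumption about your use case), a minimal sketch is to reset x.grad between the calls:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x.pow(3) + 2

y.backward(torch.ones(2, 2), retain_graph=True)
first = x.grad.clone()

x.grad = None  # drop the accumulated gradient before the next pass
y.backward(torch.ones(2, 2))
second = x.grad.clone()

print('1st pass:', first)
print('2nd pass:', second)  # same as the first pass, no accumulation
```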