
This is a question about PyTorch's `autograd.grad` and the `backward` function specifically.

I have two tensors a, b which are optimized over (i.e. require gradients).

I define loss1, loss2 = f(a,b), g(a,b). Although these are two separate functions f and g, for computational efficiency reasons, I have to compute both of them together as fg(a,b) which returns a tuple (loss1, loss2).

I need to use opt_a and opt_b (optimizers) to step a and b with the following gradients:

a.grad should equal d(loss1)/d(a)

b.grad should equal d(loss2)/d(b)

How can I achieve these gradients? I know I can run `autograd.grad(loss1, a)` and `autograd.grad(loss2, b)` to get the true gradients and assign them to `.grad` manually, but I want to use the `backward` method on loss1 and loss2.

I want to use the `backward` method because it keeps the code concise: in my case, a and b are actually the parameter lists of two neural networks, and I don't want to be manually setting `param.grad = ...` for every `param in model1.parameters()`.

Is there a clean way to do this with .backward()?
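For concreteness, the manual `autograd.grad` route mentioned above looks like this for plain tensors (`fg` below is a hypothetical stand-in that shares an intermediate between the two losses):

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

def fg(a, b):
    shared = a * b                       # shared intermediate, computed once
    return shared + a**2, shared + b**2  # stand-ins for f(a, b) and g(a, b)

loss1, loss2 = fg(a, b)

# compute each gradient separately, without touching .grad, then assign
(grad_a,) = torch.autograd.grad(loss1, a, retain_graph=True)
(grad_b,) = torch.autograd.grad(loss2, b)
a.grad, b.grad = grad_a, grad_b

print(a.grad)  # tensor(4.)  d(loss1)/d(a) = b + 2a
print(b.grad)  # tensor(5.)  d(loss2)/d(b) = a + 2b
```

This is exactly the verbose pattern the question is trying to avoid when a and b become parameter lists.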


My Attempt

I have tried multiple orderings of the following (none of them work, because the gradients accumulate for at least one of the variables):

loss1, loss2 = fg(...)

opt1.zero_grad()
loss1.backward(retain_graph=True)
opt1.step()

opt2.zero_grad()
loss2.backward()
opt2.step()

Different orderings of this result either in an accumulation of gradients (a.grad becomes d(loss1+loss2)/d(a)) or in one optimizer stepping the value of a first, after which `loss2.backward()` fails with an in-place modification error.
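To make the accumulation failure mode concrete, here is a minimal repro (with hypothetical stand-in losses that share an intermediate, as fg would):

```python
import torch

a = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(2.0, requires_grad=True)

shared = a * b         # shared intermediate, computed once
loss1 = shared + a**2  # stand-in for f(a, b)
loss2 = shared + b**2  # stand-in for g(a, b)

loss1.backward(retain_graph=True)  # a.grad = d(loss1)/d(a) = b + 2a = 4
loss2.backward()                   # accumulates d(loss2)/d(a) = b = 2 into a.grad

print(a.grad)  # tensor(6.) -- d(loss1+loss2)/d(a), not the wanted d(loss1)/d(a)
```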


akarshkumar0101
  • Did you check `a.grad` after both steps have been performed? If so, that isn't quite correct. `opt1.step()` updates the parameters based on `a.grad` at that point, but it doesn't wipe them out and further changes to them won't have any effect without calling `opt1.step()` again, so the final value in `a.grad` isn't representative of the updates that have been performed. Could you verify that `a.grad` is correct after `opt1.step()`, or even better, could you check that `a` was updated correctly at the end? – Michael Jungo May 13 '20 at 04:52
  • @MichaelJungo `a.grad` would definitely be updated correctly if it is the first `.step()` call. The problem is that since `opt_a.step()` changes the value of `a` (not `a.grad`), the call `loss2.backward()` fails because it relied on a previous version of `a`'s values. This results in the in-place version-change error. – akarshkumar0101 May 13 '20 at 05:21
  • Oh, I see. Then I can't really see a version with `.backward()` without detaching `a` at some point, but that does not seem to be an option, since that would require `f` and `g` to be independent. I must say that I'm a little puzzled as to why `a` is used as a learnable parameter with `b` being constant and vice-versa, yet they share most of the computations. – Michael Jungo May 13 '20 at 06:26
  • I'm running a special adversarial setup with two neural net models, so model1(input) and model2(input) are both used to compute loss1 and loss2. But anyway, it feels ugly :/, but I ended up just manually setting the gradients like this:
    grad_o = torch.autograd.grad(loss_o, model_o.parameters(), retain_graph=True)
    grad_d = torch.autograd.grad(loss_d, model_d.parameters())
    for i, params in enumerate(model_o.parameters()): params.grad = grad_o[i]
    for i, params in enumerate(model_d.parameters()): params.grad = grad_d[i]
    – akarshkumar0101 May 13 '20 at 07:37
  • Okay, it might be your best bet then. But I still believe that the simplified example you provided was not entirely realistic. The snippet in your last comment is what I would have imagined, for which a solution would be `loss1 = f(output1, output2.detach())` and `loss2 = g(output1.detach(), output2)`, where `f` and `g` are the specific loss calculations, which are probably insignificant in terms of computational overhead compared to the models. Especially since you can avoid having to back propagate through the whole graph twice, and do it only once per model (two subgraphs). – Michael Jungo May 13 '20 at 08:15
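The detach-based suggestion from the last comment can be sketched like this (the models, optimizers, and loss expressions below are hypothetical placeholders; the point is that each loss only back-propagates into its own model):

```python
import torch

torch.manual_seed(0)
model1 = torch.nn.Linear(3, 1)
model2 = torch.nn.Linear(3, 1)
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)

x = torch.randn(4, 3)
out1, out2 = model1(x), model2(x)  # each forward pass runs only once

# loss1 sees out2 as a constant, loss2 sees out1 as a constant
loss1 = ((out1 - out2.detach()) ** 2).mean()  # grads flow only into model1
loss2 = ((out1.detach() + out2) ** 2).mean()  # grads flow only into model2

opt1.zero_grad()
opt2.zero_grad()
(loss1 + loss2).backward()  # a single pass; no retain_graph needed
opt1.step()
opt2.step()
```

Because of the `detach()` calls, model1's parameters receive exactly d(loss1)/d(params) and model2's exactly d(loss2)/d(params), even though both losses are computed from both outputs.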
