
Suppose I have a tensor Y that is (directly or indirectly) computed from a tensor X.

Normally, when I apply torch.autograd.grad(Y, X, grad_outputs=torch.ones_like(Y)), I get a gradient mask that has the same shape as X. This mask is a weighted sum of the gradients of the elements of Y w.r.t. X (a vector-Jacobian product).

Is it possible to get a gradient mask of the same shape as Y instead, in which each element mask[i][j] is the sum of the gradients of Y[i][j] w.r.t. all elements of X?

This is equivalent to summing the Jacobian J(Y,X) over the dimensions of X instead of over the dimensions of Y.

>>> import torch
>>> X = torch.eye(2)
>>> X.requires_grad_()
# X = [1 0]
#     [0 1]

>>> Y = torch.sum(X*X, dim=0)
# Y = [1, 1]

>>> torch.autograd.grad(Y, X, grad_outputs=torch.ones_like(Y), retain_graph=True)
(tensor([[2., 0.],
         [0., 2.]]),)

But instead, I want:

# [2, 2]

because torch.sum(torch.autograd.grad(Y[0], X, retain_graph=True)[0]) and torch.sum(torch.autograd.grad(Y[1], X, retain_graph=True)[0]) both equal 2.

It would be easy to calculate the Jacobian of Y w.r.t. X and simply sum it over the dimensions of X. However, this is infeasible memory-wise, as the functions I work with are neural networks with huge inputs and outputs.
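
For the toy example above, that brute-force version would look roughly like this (just a sketch using torch.autograd.functional.jacobian, shown only to make the target concrete):

import torch
from torch.autograd.functional import jacobian

X = torch.eye(2)

def build_Y(x):
    return torch.sum(x * x, dim=0)

# The full Jacobian has shape (2, 2, 2): the dimensions of Y first, then those of X.
J = jacobian(build_Y, X)
# Summing over the dimensions of X gives the desired mask with the shape of Y.
mask = J.sum(dim=(1, 2))
print(mask)  # values: [2., 2.]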

Calculating each gradient separately (as I did in the comments) is also undesirable, because it is far too slow.
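
Concretely, that slow approach amounts to a loop like the following (a sketch of the idea, not my exact comment snippet):

import torch

X = torch.eye(2, requires_grad=True)
Y = torch.sum(X * X, dim=0)

# One backward pass per element of Y; far too slow for large outputs.
mask = torch.stack([
    torch.sum(torch.autograd.grad(Y[i], X, retain_graph=True)[0])
    for i in range(Y.shape[0])
])
print(mask)  # values: [2., 2.]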

Jonas De Schouwer
  • How do you plan to use such resulting tensor? What is the mathematical reasoning behind it? – Ivan Aug 30 '21 at 16:29
  • I am writing an LRP (layer-wise relevance propagation) implementation that requires such a gradient mask. Anyway, I don’t think this operation is too far-fetched. – Jonas De Schouwer Aug 30 '21 at 16:43

1 Answer


If you run the PyTorch nightly build, https://github.com/pytorch/pytorch/issues/10223 is partially implemented and should do what you want for most simple graphs. You could also try the trick described at https://j-towns.github.io/2017/06/12/A-new-trick.html .
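
For reference, here is a minimal sketch of that trick on the toy example from the question: a first call to torch.autograd.grad with a dummy grad_outputs builds J^T u as a function of the dummy, and a second call differentiates through it to get J v, with v a tensor of ones over X's dimensions:

import torch

X = torch.eye(2, requires_grad=True)
Y = torch.sum(X * X, dim=0)

# Dummy variable with the shape of Y; its values are irrelevant.
u = torch.zeros_like(Y, requires_grad=True)

# First backward pass: g = J^T u as a (differentiable) function of u.
g = torch.autograd.grad(Y, X, grad_outputs=u, create_graph=True)[0]

# Second backward pass: differentiate g w.r.t. u against v = ones(X.shape),
# which yields J v, i.e. the Jacobian summed over the dimensions of X.
v = torch.ones_like(X)
print(torch.autograd.grad(g, u, grad_outputs=v)[0])  # values: [2., 2.]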

EDIT: It looks like https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html#torch.autograd.functional.jvp implements the backward-of-backward trick for you, so you could just do:

import torch
from torch.autograd.functional import jvp

X = torch.eye(2)
X.requires_grad_()

def build_Y(x):
    return torch.sum(x * x, dim=0)

# jvp returns (Y, J @ v); with v = ones(X.shape) the second element is the
# Jacobian summed over the dimensions of X, i.e. the mask you want.
print(jvp(build_Y, X, torch.ones(X.shape))[1])  # values: [2., 2.]
mlucy