
The PyTorch computation graph is made up of grad_fn nodes. The grad_fn object of the terminating node has an attribute called next_functions, which is a tuple of (function, index) pairs. I understand that I can use the first element (index 0) of each pair to reconstruct the computation graph for gradients. But what does the second element (index 1) of each pair mean?

One of the answers on the PyTorch forums says:

The number is the input number to the next backward function, so can only be non-zero when a function has multiple differentiable outputs (there aren’t that many, but e.g. the RNN functions typically do).

But I don't understand this statement. Can someone explain it, perhaps with an example?

pecey

1 Answer


I am not an expert on PyTorch, but I will try to answer your question using this example:

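import torch
# unbind splits the 2-element tensor into two scalar tensors
a, b = torch.randn(2, requires_grad=True).unbind()
c = a + b
print(c.grad_fn.next_functions)
# prints something like (class name and address vary by version):
# ((<UnbindBackward object at 0x7f0ea438de80>, 0), (<UnbindBackward object at 0x7f0ea438de80>, 1))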

Now, think of any PyTorch function as producing "a list of outputs" rather than "an output". If a function produces just one output (the typical case), its list is simply [output]. If it produces several outputs, the list has more than one element, e.g. [output0, output1]. With that in mind, I read each tuple as follows:
(grad_fn: the function object that produced this tensor, i: the index of this tensor in that function's list of outputs, which is usually zero because most functions have a single output)
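As a quick check of the typical case, here is a minimal sketch using only single-output operations. The tensor names x, y and z are made up for illustration, and the exact class names printed may differ between PyTorch versions.

```python
import torch

x = torch.randn(3, requires_grad=True)  # leaf tensor
y = x.exp()                             # exp produces a single output
z = y.sum()                             # sum produces a single output

# Each entry is (backward node of an input, index of that input among
# the outputs of the function that produced it).
single_output_edges = z.grad_fn.next_functions
print(single_output_edges)

# A leaf tensor appears as an accumulate-grad node, again with index 0.
leaf_edges = y.grad_fn.next_functions
print(leaf_edges)
```

Since exp and sum each return one tensor, the only possible index here is 0.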

Applying this to the unbind example above, the unbind function has two outputs: a at index 0 of the outputs "list", and b at index 1. The following points reason through the graph:

  • c.grad_fn is an AddBackward object, since c is the result of an addition. That addition has two branches in the computational graph, one per operand and in operand order: the first branch is for a and the second is for b.
  • Let's say output_tuple = c.grad_fn.next_functions. output_tuple has 2 elements: output_tuple[0] is a's (grad_fn, index of a in the unbind outputs), and output_tuple[1] is b's (grad_fn, index of b in the unbind outputs).
  • a's grad_fn and b's grad_fn are the same (i.e. the exact same object; a.grad_fn == b.grad_fn returns True). However, the second element of a's tuple is 0 because a is the first output of unbind, and the second element of b's tuple is 1 because b is the second output of unbind.
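The three points above can be checked directly with a short sketch. The variable names edges, fn_a, idx_a, fn_b and idx_b are introduced here only for illustration, and the printed class names depend on the PyTorch version.

```python
import torch

a, b = torch.randn(2, requires_grad=True).unbind()
c = a + b

# One (node, index) pair per operand of the addition, in operand order.
edges = c.grad_fn.next_functions
(fn_a, idx_a), (fn_b, idx_b) = edges

print(type(c.grad_fn).__name__)     # backward node of the addition
print(type(fn_a).__name__, idx_a)   # unbind backward node, first output
print(type(fn_b).__name__, idx_b)   # unbind backward node, second output
print(fn_a is fn_b)                 # do both edges point at one node object?
```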

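In other words, the second entry of each tuple in next_functions is the index of the tensor in the list of outputs of the function that produced it. This is consistent with the forum quote: the backward function of a multi-output operation takes one incoming gradient per forward output, so the index says which of those gradient inputs the edge feeds.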
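To see what the index does during the backward pass, here is a small sketch that uses only the third output of unbind. The names x, parts, y, node and index are assumptions made for this illustration, not part of the original example.

```python
import torch

x = torch.randn(3, requires_grad=True)
parts = x.unbind()        # three outputs: parts[0], parts[1], parts[2]
y = parts[2].exp()        # only the third output is used downstream

# The single edge out of the exp backward node points at the unbind
# backward node, tagged with the position of parts[2] among its outputs.
node, index = y.grad_fn.next_functions[0]
print(type(node).__name__, index)

# Backpropagating routes the gradient into that slot only; the slots for
# the unused outputs receive zeros.
y.backward()
print(x.grad)
```

Only the last entry of x.grad is non-zero, because the gradient arrived at the unbind backward node through the slot given by the index.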