I am not an expert on PyTorch, but I will try to answer your question using this example:
import torch

a, b = torch.randn(2, requires_grad=True).unbind()
c = a + b
print(c.grad_fn.next_functions)
# ((<UnbindBackward object at 0x7f0ea438de80>, 0), (<UnbindBackward object at 0x7f0ea438de80>, 1))
Now, think of any PyTorch function as producing "a list of outputs" rather than "an output". If a function produces just one output (the typical case), it produces a list equal to [output]. However, if the function produces several outputs, its list has length greater than 1, e.g. [output0, output1].
From that, I understand the tuple constituents as follows:
(grad_fn: the function object that produced this tensor, i: the index of the tensor in that function's list of outputs, which is typically zero since functions typically have one output)
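To make the typical single-output case concrete, here is a minimal sketch. The tensors x, y and z are illustrative names that are not part of the original example, and the exact class name of the backward node may differ between PyTorch versions.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2      # multiplication produces a single output
z = y.sum()    # sum also produces a single output

# z was computed from y alone, so next_functions has exactly one entry:
# (the backward node that produced y, index of y among that node's outputs)
fn, idx = z.grad_fn.next_functions[0]
print(type(fn).__name__, idx)
```

Since the multiplication has only one output, the index printed here should be 0.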
Applying this understanding to the code, the unbind function has two outputs: a at index 0 of the outputs 'list', and b at index 1 of the outputs 'list'. The following points reason through the graph:
- c.grad_fn is an AddBackward object, since c is the result of an addition operation. That addition has two branches in the computational graph, since it adds two operands, a and b. The first branch is for a and the second branch is for b (one branch per operand, in order).
- Let's say output_tuple = c.grad_fn.next_functions. Then output_tuple has 2 elements: output_tuple[0] is a's pair (grad_fn, index of a in the unbind outputs), and output_tuple[1] is b's pair (grad_fn, index of b in the unbind outputs).
- a's grad_fn and b's grad_fn are the same, i.e. the exact same object (a.grad_fn == b.grad_fn returns True). However, a's tuple has 0 as its second element since a is the first output of the unbind function, and b's tuple has 1 as its second element since b is the second output of the unbind function.
In short, the second entry in each tuple is the index of the tensor in the list of outputs of the function that produced it.
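The points above can be checked with a short sketch like the one below. It reuses the names from the example. The swapped-operand check at the end is my own addition, meant to show that the index belongs to the tensor rather than to its position in the sum. Node class names may vary across PyTorch versions.

```python
import torch

a, b = torch.randn(2, requires_grad=True).unbind()
c = a + b

# Each entry is (backward node of an operand, index of that operand
# among the outputs of the function that produced it).
output_tuple = c.grad_fn.next_functions
for name, (fn, idx) in zip(("a", "b"), output_tuple):
    print(name, type(fn).__name__, idx)

# Both operands come from the same unbind call, so they share one backward node.
same_node = a.grad_fn is b.grad_fn
print(same_node)

# Swapping the operands swaps the entries: the index follows the tensor,
# not its position in the addition.
swapped = [idx for _, idx in (b + a).grad_fn.next_functions]
print(swapped)
```

If this understanding is right, the loop reports index 0 for a and index 1 for b, and the swapped sum reports the indices in the opposite order.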