I think it is. Let me walk through an example with code.
First we create the qvalues tensor and tell PyTorch that we want to compute its gradients:

import torch

qvalues = torch.rand((5, 5), requires_grad=True)
Now we create a tensor to index it and obtain a 5x2 tensor as the result (I think this is similar to the selection you wanted to perform with qvalues[range(5), [0, 1, 0, 1, 0]]):
y = torch.LongTensor([1, 3])   # column indices to select
new_qvalues = qvalues[:, y]    # shape (5, 2): columns 1 and 3 of every row
We can check that the slice new_qvalues taken from the original qvalues still tracks gradients:

print(new_qvalues.requires_grad)  # True
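The same holds for the exact indexing from your question. As a quick side check (assuming the same qvalues as above; selected is just an illustrative name):

selected = qvalues[range(5), [0, 1, 0, 1, 0]]  # one element per row, shape (5,)
print(selected.requires_grad)  # True, this selection also tracks gradients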
Now we perform our mathematical operation. In this example I square new_qvalues, because we know that its derivative will be 2 * new_qvalues:

qvalues_a = new_qvalues ** 2   # elementwise square
Now we have to compute the gradients of qvalues_a. Since qvalues_a is not a scalar, backward() needs a gradient argument of the same shape, so we pass a tensor of ones. We also set retain_graph=True so the computation graph's buffers are not freed after the backward pass, which lets us backpropagate through it again later if needed.

qvalues_a.backward(torch.ones(new_qvalues.shape), retain_graph=True)
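As an aside (not part of the original example), if you reduce the result to a scalar first, backward() needs no gradient argument at all. A minimal separate sketch, using a fresh tensor w so its gradients don't mix with qvalues:

w = torch.rand((5, 5), requires_grad=True)  # fresh tensor, separate from qvalues
loss = (w[:, y] ** 2).sum()                 # scalar result, so backward() takes no argument
loss.backward()
print(w.grad[:, y])                         # equals 2 * w[:, y], same pattern as above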
Now we can go back to the original qvalues and check whether its gradients have been computed:
print(qvalues)
print(qvalues.grad)
# output of the print statements
#tensor([[ 0.9677, 0.4303, 0.2036, 0.3870, 0.6085],
# [ 0.8876, 0.8695, 0.2028, 0.3283, 0.1560],
# [ 0.1764, 0.4718, 0.5418, 0.5167, 0.6200],
# [ 0.7610, 0.9322, 0.5584, 0.5589, 0.8901],
# [ 0.8146, 0.7296, 0.8036, 0.5277, 0.5754]])
#tensor([[ 0.0000, 0.8606, 0.0000, 0.7739, 0.0000],
# [ 0.0000, 1.7390, 0.0000, 0.6567, 0.0000],
# [ 0.0000, 0.9435, 0.0000, 1.0334, 0.0000],
# [ 0.0000, 1.8645, 0.0000, 1.1178, 0.0000],
# [ 0.0000, 1.4592, 0.0000, 1.0554, 0.0000]])
We can observe that the gradients have been computed only at the selected indices (columns 1 and 3). To be sure, we add a quick check that the value of qvalues.grad at the selected slice equals the expected derivative 2 * new_qvalues:
assert torch.equal(qvalues.grad[:, y], 2 * new_qvalues)
The assertion does not raise any error, so I would say that yes, you can get the gradient of a slice.
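One caveat to keep in mind: .grad accumulates across backward calls, so before running another backward pass on the same qvalues you may want to clear the old gradients. A small sketch:

qvalues.grad.zero_()                    # or: qvalues.grad = None
(qvalues[:, y] ** 2).sum().backward()   # new forward and backward pass
print(qvalues.grad[:, y])               # again equals 2 * qvalues[:, y]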