I am performing multi-label image classification in PyTorch, and would like to compute the gradients of all outputs at ground truth labels for each input with respect to the input. I would preferably like to do this in a single backward pass for a batch of inputs.
For example:
inputs = torch.randn((4,3,224,224), requires_grad=True) # Batch of 4 inputs; requires_grad so gradients w.r.t. the inputs can be taken
targets = torch.tensor([[1,0,1],[1,0,0],[0,0,1],[1,1,0]]) # Labels for each input
outputs = model(inputs) # 4 x 3 tensor of outputs
Here, I want to find the gradient of:
output[0,0] and output[0,2] with respect to input[0]
output[1,0] with respect to input[1]
output[2,2] with respect to input[2]
output[3,0] and output[3,1] with respect to input[3]
Is there any way to do this in a single backward pass?
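For concreteness, here is a naive multi-pass sketch of what I want, with one backward call per (sample, label) pair; this is the loop I would like to collapse into a single pass (it assumes the same model as in the snippet above):

import torch

# Same setup as above (model is the multi-label network).
inputs = torch.randn((4, 3, 224, 224), requires_grad=True)
targets = torch.tensor([[1, 0, 1], [1, 0, 0], [0, 0, 1], [1, 1, 0]])
outputs = model(inputs)  # shape (4, 3)

# One backward call per (sample, label) pair with targets[i, j] == 1
# (6 calls for this batch).
per_label_grads = []
for i, j in zip(*torch.where(targets == 1)):
    g = torch.autograd.grad(outputs[i, j], inputs, retain_graph=True)[0]  # (4, 3, 224, 224)
    per_label_grads.append(g[i])  # only input[i] gets a nonzero gradient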
If my targets were one-hot, i.e., there were only one label per input, I could use:
gt_classes = torch.where(targets==1)[1]
gather_outputs = torch.gather(outputs, 1, gt_classes.unsqueeze(-1))
grads = torch.autograd.grad(torch.unbind(gather_outputs), inputs)[0] # 4 x 3 x 224 x 224
This gives the gradient of output[i, gt_classes[i]] with respect to input[i].
For my case, it looks like the is_grads_batched argument of torch.autograd.grad might be relevant, but it's not very clear how it should be used.
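From the docs, my understanding is that is_grads_batched=True treats the leading dimension of grad_outputs as a batch of vectors for vector-Jacobian products. My best guess at applying it here is the sketch below, reusing inputs, targets, and outputs from above; the one-hot grad_outputs construction and the final slicing are my own guesses rather than documented usage:

# Build one one-hot "vector" per (sample, label) pair with targets[i, j] == 1.
sample_idx, label_idx = torch.where(targets == 1)           # both shape (6,)
n = len(sample_idx)
grad_out = torch.zeros(n, *outputs.shape)                   # (6, 4, 3)
grad_out[torch.arange(n), sample_idx, label_idx] = 1.0

# A single batched call should compute all the vector-Jacobian products at once.
grads = torch.autograd.grad(outputs, inputs,
                            grad_outputs=grad_out,
                            is_grads_batched=True)[0]       # (6, 4, 3, 224, 224)

# Since output[i] depends only on input[i], row k should be nonzero only at
# batch position sample_idx[k]; slicing that out gives the per-label gradients.
per_label_grads = grads[torch.arange(n), sample_idx]        # (6, 3, 224, 224)

Is that the intended usage of is_grads_batched, or is there a more direct way?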