
I am performing multi-label image classification in PyTorch and would like to compute, for each input, the gradients of the outputs at that input's ground-truth labels with respect to the input. Preferably, I would like to do this in a single backward pass for a whole batch of inputs.

For example:

inputs = torch.randn((4,3,224,224), requires_grad=True) # Batch of 4 inputs
targets = torch.tensor([[1,0,1],[1,0,0],[0,0,1],[1,1,0]]) # Labels for each input
outputs = model(inputs) # 4 x 3 vector

Here, I want to find the gradient of:

  • outputs[0,0] and outputs[0,2] with respect to inputs[0]
  • outputs[1,0] with respect to inputs[1]
  • outputs[2,2] with respect to inputs[2]
  • outputs[3,0] and outputs[3,1] with respect to inputs[3]

Is there any way to do this in a single backward pass?
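For concreteness, this is the straightforward loop version I want to replace — one backward pass per (sample, label) pair. (The tiny linear model and small input shape here are assumptions, standing in for the real network, so the sketch runs quickly.)

```python
import torch

# Tiny stand-in model and small shapes (assumptions, not the real setup).
torch.manual_seed(0)
model = torch.nn.Linear(12, 3)

inputs = torch.randn(4, 12, requires_grad=True)  # batch of 4 inputs
targets = torch.tensor([[1, 0, 1],
                        [1, 0, 0],
                        [0, 0, 1],
                        [1, 1, 0]])

# One backward pass per (sample, label) pair -- correct but slow.
grads = {}
for i, j in zip(*torch.where(targets == 1)):
    out = model(inputs)  # re-run forward: the graph is freed after each grad()
    # Gradient of outputs[i, j] w.r.t. inputs[i] (cross-sample grads are zero).
    g = torch.autograd.grad(out[i, j], inputs)[0][i]
    grads[(i.item(), j.item())] = g
```

This computes exactly the six gradients listed above, but needs one forward and one backward pass per pair.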


If my targets were one-hot, i.e., there was exactly one label per input, I could use:

gt_classes = torch.where(targets==1)[1]
gather_outputs = torch.gather(outputs, 1, gt_classes.unsqueeze(-1))
grads = torch.autograd.grad(torch.unbind(gather_outputs), inputs)[0] # 4 x 3 x 224 x 224

This gives the gradient of outputs[i, gt_classes[i]] with respect to inputs[i].
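As a sanity check, here is a runnable miniature of that one-hot case (the small linear model and shapes are assumptions made so it executes quickly):

```python
import torch

# Miniature of the one-hot (single-label) case with stand-in shapes.
torch.manual_seed(0)
model = torch.nn.Linear(12, 3)
inputs = torch.randn(4, 12, requires_grad=True)
targets = torch.tensor([[0, 1, 0],
                        [1, 0, 0],
                        [0, 0, 1],
                        [0, 1, 0]])  # exactly one label per input
outputs = model(inputs)

gt_classes = torch.where(targets == 1)[1]                        # shape (4,)
gather_outputs = torch.gather(outputs, 1, gt_classes.unsqueeze(-1))
grads = torch.autograd.grad(torch.unbind(gather_outputs), inputs)[0]
# grads has the same shape as inputs; since samples in a batch do not
# interact, grads[i] is the gradient of outputs[i, gt_classes[i]]
# w.r.t. inputs[i].
```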

For my case, it looks like the is_grads_batched argument of torch.autograd.grad might be relevant, but it is not very clear how it should be used.
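For reference, a minimal sketch of how is_grads_batched could be applied here, again with a tiny stand-in model (an assumption) instead of the real network: stack one one-hot grad_outputs "cotangent" per (sample, label) pair, and let is_grads_batched=True vectorize the backward over that leading dimension, so all pairs are handled in a single grad call.

```python
import torch

# Stand-in model and small shapes (assumptions, not the real setup).
torch.manual_seed(0)
model = torch.nn.Linear(12, 3)

inputs = torch.randn(4, 12, requires_grad=True)  # batch of 4 inputs
targets = torch.tensor([[1, 0, 1],
                        [1, 0, 0],
                        [0, 0, 1],
                        [1, 1, 0]])
outputs = model(inputs)                          # 4 x 3

# One (sample, label) pair per ground-truth 1 in `targets`: 6 pairs here.
rows, cols = torch.where(targets == 1)

# One one-hot grad_outputs tensor per pair: shape (6, 4, 3).
grad_outs = torch.zeros(len(rows), *outputs.shape)
grad_outs[torch.arange(len(rows)), rows, cols] = 1.0

# is_grads_batched=True treats the leading dim of grad_outputs as a
# batch and vmaps the backward over it: a single (vectorized) grad call.
grads = torch.autograd.grad(outputs, inputs,
                            grad_outputs=grad_outs,
                            is_grads_batched=True)[0]  # 6 x 4 x 12

# grads[k] is the gradient of outputs[rows[k], cols[k]] w.r.t. all
# inputs; only the rows[k]-th slice is nonzero, since samples in a
# batch do not interact.
per_pair = grads[torch.arange(len(rows)), rows]        # 6 x 12
```

Whether the vmapped backward is actually faster than a plain loop depends on the model; it trades compute for memory by materializing all cotangents at once.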

GoodDeeds
  • Is it a batch size of `4` or do you mean `4d` input (`5d` with batch) i.e. `[b, 4, 3, 224, 224]`? – Wakeme UpNow Feb 15 '23 at 08:43
  • @WakemeUpNow The batch size is 4. – GoodDeeds Feb 15 '23 at 08:48
  • So in the forward pass, does each item in a batch interact with the others? If not, then the gradient is `0`, except for the corresponding index, of course, i.e. the gradient of `output[0]` w.r.t. `input[0]` still exists, just not `output[0]` w.r.t. `input[1]`. – Wakeme UpNow Feb 15 '23 at 08:50
  • @WakemeUpNow No, they do not, and I realize I mislabeled the output indices in my question. Updating now. Thanks for pointing this out. – GoodDeeds Feb 15 '23 at 08:52
