
I am trying to get the gradient of a sum over selected indices of an array, using bincount. However, PyTorch does not implement the gradient for bincount. This can be implemented with a loop and torch.sum, but it is too slow. Is it possible to do this efficiently in PyTorch (maybe with einsum or index_add)? Of course, we could loop over the indices and add the terms one by one, but that would grow the computational graph significantly and performs very poorly.

import torch
from torch import autograd
import numpy as np

tt = lambda x, grad=True: torch.tensor(x, requires_grad=grad)
inds = tt([1, 5, 7, 1], False).long()
y = tt(np.arange(4) + 0.1).float()

# sum y * y into the bins selected by inds
sum_y_section = torch.bincount(inds, y * y, minlength=8)
# sum_y_section = torch.sum(y * y)  # the plain scalar sum works fine

# fails: PyTorch has no derivative implemented for bincount
grad = autograd.grad(sum_y_section, y, create_graph=True, allow_unused=False)
print("sum_y_section", sum_y_section)
print("grad", grad)
Roy

3 Answers


We can use scatter_reduce, a new feature in PyTorch v1.11:

bincount = lambda inds, arr: torch.scatter_reduce(arr, 0, inds, reduce="sum")
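
Note that this one-liner uses the v1.11 prototype signature; from v1.12 onward scatter_reduce takes an explicit src argument. A minimal sketch of the same idea with the newer API, applied to the tensors from the question (the output size of 8 mirrors the question's minlength=8):

import torch

inds = torch.tensor([1, 5, 7, 1])
y = (torch.arange(4, dtype=torch.float32) + 0.1).requires_grad_()

# scatter y * y into 8 bins; unlike bincount, this op is differentiable
sum_y_section = torch.zeros(8).scatter_reduce(0, inds, y * y, reduce="sum")
sum_y_section.sum().backward()
print(y.grad)  # tensor([0.2000, 2.2000, 4.2000, 6.2000]), i.e. 2 * y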
Roy

I’d try to use a hook to manipulate the gradient in a custom way; a sketch of that idea follows.
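
A minimal sketch of supplying the missing gradient yourself, using a custom torch.autograd.Function rather than a literal tensor hook (the class name and the gradient formula grad_out[inds] are my own illustration, not from this answer):

import torch

class BincountWithGrad(torch.autograd.Function):
    """Weighted bincount whose backward scatters the bin gradients
    back to the weights: out[b] = sum of weights[i] where inds[i] == b."""

    @staticmethod
    def forward(ctx, weights, inds, minlength):
        ctx.save_for_backward(inds)
        return torch.bincount(inds, weights, minlength=minlength)

    @staticmethod
    def backward(ctx, grad_out):
        (inds,) = ctx.saved_tensors
        # each weight contributes exactly once, to bin inds[i],
        # so its gradient is the gradient of that bin
        return grad_out[inds], None, None

inds = torch.tensor([1, 5, 7, 1])
y = (torch.arange(4, dtype=torch.float32) + 0.1).requires_grad_()
BincountWithGrad.apply(y * y, inds, 8).sum().backward()
print(y.grad)  # 2 * y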

emely_pi

As of PyTorch 1.13, torch.scatter_reduce has a src positional argument. This simple example demonstrates the equivalence between bincount and scatter_reduce:

import torch

num_bins = 8
bins = torch.zeros(num_bins)

# generate indices and weights
num_indices = 100
indices = torch.randint(num_bins, size=(num_indices,))
weights = torch.rand(num_indices)

# Counting indices

# with torch.bincount
counts1 = torch.bincount(indices, minlength=num_bins)

# with torch.scatter_reduce
counts2 = bins.scatter_reduce(0, indices, torch.ones(num_indices), reduce='sum')
print(counts1)
print(counts2)

# Binning weights

# with torch.bincount
binned_wts1 = indices.bincount(weights, minlength=num_bins)

# with torch.scatter_reduce
binned_wts2 = bins.scatter_reduce(0, indices, weights, reduce='sum')

print(binned_wts1)
print(binned_wts2)
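
One more point that matters for the original question (my observation, not part of the equivalence demo above): with reduce='sum', scatter_reduce propagates gradients to its src argument, which bincount does not. A minimal standalone sketch:

import torch

num_bins = 8
indices = torch.randint(num_bins, size=(100,))
weights = torch.rand(100, requires_grad=True)

# unlike bincount, scatter_reduce with reduce='sum' is differentiable w.r.t. src
binned = torch.zeros(num_bins).scatter_reduce(0, indices, weights, reduce='sum')
binned.sum().backward()
print(weights.grad)  # all ones: each weight enters exactly one bin, once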

lw123