
I have a large vector that I would like to update by adding an offset to specific elements. I specify a vector of indices that I want to update (call the index vector ix), and for each index, a value that I want to add to that element (call the value vector vals). If all entries of the index vector are unique, the following code suffices:

import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,2], dtype=torch.long)
vals = torch.tensor([0.2, 0.5], dtype=torch.float)
vec[ix] += vals  # correct only when ix has no repeated entries

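However, this does not work if there are repeated indices in ix: each duplicate position is written back rather than accumulated, so only one of its values ends up being added. A naive approach for the case of repeated indices is as follows: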

for i in range(len(ix)):
    vec[ix[i]] += vals[i]

But this doesn't scale well: it is very slow when ix is large. Are there any faster ways to do this? If there were a fast way to sum all entries of vals that share the same index in ix, the solution would be easy.
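A minimal sketch that makes the failure concrete, using illustrative values where index 0 appears twice. It contrasts the one-line fancy-index update with the naive loop. Which of the duplicate writes survives in the first case is an implementation detail, so the comments only claim that the two values are not summed.

```python
import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0, 0, 2], dtype=torch.long)   # index 0 appears twice
vals = torch.tensor([0.2, 0.1, 0.5], dtype=torch.float)

# Fancy-index "+=" gathers vec[ix], adds vals, then writes the result back.
# Both duplicate positions write to vec[0], so only one of 0.2 / 0.1 survives.
fancy = vec.clone()
fancy[ix] += vals

# The naive loop applies every update, so vec[0] ends up as 0.2 + 0.1.
looped = vec.clone()
for i in range(len(ix)):
    looped[ix[i]] += vals[i]

print(fancy)
print(looped)
```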

Update:
I found one solution that works pretty well, described below. I'd still love feedback on better solutions.

# get unique indices
ix_unique = torch.unique(ix)

# for each unique index, get sum of all vals with that index
vals_unique = torch.stack([
    torch.sum(torch.where(ix==i, vals, torch.zeros_like(vals))) 
    for i in ix_unique
])

# update vec
vec[ix_unique] += vals_unique
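The list comprehension above scans all of ix once per unique index. Here is a sketch of the same unique-then-sum idea without the Python loop. It assumes `torch.unique(..., return_inverse=True)` combined with `torch.bincount(..., weights=...)`. This combination is a substitution and not part of the original update, and the example values are illustrative.

```python
import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0, 0, 2], dtype=torch.long)
vals = torch.tensor([0.2, 0.1, 0.5], dtype=torch.float)

# return_inverse maps each entry of ix to its position in ix_unique,
# so entries that share an index share the same bin number.
ix_unique, inverse = torch.unique(ix, return_inverse=True)

# bincount with weights sums vals within each bin in a single call,
# replacing the Python-level loop over ix_unique.
vals_unique = torch.bincount(inverse, weights=vals, minlength=ix_unique.numel())
vals_unique = vals_unique.to(vec.dtype)  # keep the update in vec's dtype

# ix_unique has no duplicates, so plain fancy-index += is now safe.
vec[ix_unique] += vals_unique
print(vec)
```

The final line is the same as in the update above. Only the computation of vals_unique changes.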
zillo
  • You can write down your own answer, that indicates to others that there is a solution! – MBT Nov 15 '18 at 18:17

2 Answers


For cases where you want multiple updates to the same index in ix to accumulate, there is also a library called pytorch_scatter. With it, having e.g. 3 identical indices within ix results in the sum of the 3 corresponding entries of vals being added to that index (3*val when those entries are equal).
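Core PyTorch also has a built-in `Tensor.scatter_add_`, which covers this 1-D case without installing an extra package. The sketch below uses the same illustrative data as the other answer. It relies on `scatter_add_` rather than the pytorch_scatter package mentioned here, and that package's own API is not shown.

```python
import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0, 0, 2], dtype=torch.long)
vals = torch.tensor([0.2, 0.1, 0.5], dtype=torch.float)

# scatter_add_ performs vec[ix[i]] += vals[i] for every i along dim 0,
# so the two updates to index 0 are summed rather than overwritten.
vec.scatter_add_(0, ix, vals)
print(vec)
```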

Max

Use torch.index_add():

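import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,0,2], dtype=torch.long)
vals = torch.tensor([0.2,0.1,0.5], dtype=torch.float)
# Adds vals[i] to vec[ix[i]] along dim 0; repeated indices accumulate.
# This out-of-place call returns a new tensor and leaves vec unchanged.
print(torch.index_add(vec, 0, ix, vals))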

and you will get

tensor([0.3000, 0.0000, 0.5000, 0.0000])

Reference: the official PyTorch documentation for torch.index_add
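The question's `vec[ix] += vals` updates vec in place, while the torch.index_add call above returns a new tensor. For an in-place update, here is a short sketch with the same illustrative data. It shows `Tensor.index_add_` and `Tensor.index_put_` with `accumulate=True`, both offered as alternatives rather than as part of the original answer.

```python
import torch

vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0, 0, 2], dtype=torch.long)
vals = torch.tensor([0.2, 0.1, 0.5], dtype=torch.float)

# In-place counterpart of torch.index_add: modifies vec directly.
vec.index_add_(0, ix, vals)

# index_put_ with accumulate=True also sums repeated indices in place;
# it takes a tuple of index tensors, one per indexed dimension.
vec2 = torch.zeros(4, dtype=torch.float)
vec2.index_put_((ix,), vals, accumulate=True)

print(vec)
print(vec2)
```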

ZhaoYi