1

I have a B x M x N tensor, X, and a B x 1 tensor, Y, which holds, for each batch element, the index along dimension 1 of X that I want to keep. What is the shorthand for this slice so that I can avoid a loop?

Essentially I want to do this:

Z = torch.zeros(B,N)

for i in range(B):
    Z[i] = X[i][Y[i]]
Megan Hardy

2 Answers

3

The following code does the same thing as the loop. The difference is that instead of indexing Z, X and Y one element at a time, we index them all at once using the index array i.

import numpy as np

B, M, N = 13, 7, 19

X = np.random.randint(100, size= [B,M,N])
Y = np.random.randint(M  , size= [B,1])
Z = np.random.randint(100, size= [B,N])

i = np.arange(B)
Y = Y.ravel()    # reducing array to rank-1, for easy indexing

Z[i] = X[i,Y[i],:]    # pairs i[b] with Y[b] for every batch element b

The last line can be simplified step by step; each line below is equivalent to the one before it:

>> Z[i] = X[i,Y[i],:]
>> Z[i] = X[i,Y[i]]
>> Z[i] = X[i,Y]
>> Z    = X[i,Y]
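
As a quick check, the following sketch compares each simplification step against the explicit loop from the question. The fixed seed and the names `Z_loop`, `steps` and `all_match` are illustrative assumptions, not part of the original answer.

```python
import numpy as np

B, M, N = 13, 7, 19
rng = np.random.default_rng(0)  # fixed seed, chosen only for reproducibility

X = rng.integers(100, size=(B, M, N))
Y = rng.integers(M, size=(B, 1)).ravel()  # flatten (B, 1) to (B,)
i = np.arange(B)

# Reference result from the explicit loop in the question
Z_loop = np.zeros((B, N), dtype=X.dtype)
for b in range(B):
    Z_loop[b] = X[b, Y[b]]

# Each simplification step selects the same rows
steps = [X[i, Y[i], :], X[i, Y[i]], X[i, Y]]
all_match = all(np.array_equal(s, Z_loop) for s in steps)
print(all_match)
```

All three expressions use the same pair of index arrays, so they should agree with the loop element for element.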

Equivalent PyTorch code:

import torch

B, M, N = 5, 7, 3

X = torch.randint(100, size= [B,M,N])
Y = torch.randint(M  , size= [B,1])
Z = torch.randint(100, size= [B,N])

i = torch.arange(B)
Y = Y.ravel()    # flatten (B, 1) to (B,) so it pairs one-to-one with i

Z = X[i,Y]       # shape (B, N)
hammi
  • 1
    they say `Y` is of shape `(B, 1)` so you might want to change to `Y.view(-1)` or something similar in the very last expression. – Mustafa Aydın Jul 16 '21 at 18:55
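
The shape issue raised in this comment can be seen in a small sketch: if `Y` keeps its `(B, 1)` shape, it broadcasts against `i` and the result is no longer `(B, N)`. The sizes are the ones used in the answer; the names `wrong` and `right` are illustrative.

```python
import torch

B, M, N = 5, 7, 3
X = torch.randint(100, size=(B, M, N))
Y = torch.randint(M, size=(B, 1))
i = torch.arange(B)

# i has shape (B,) and Y has shape (B, 1): the two index tensors
# broadcast to (B, B), so every batch index is paired with every Y entry.
wrong = X[i, Y]
print(wrong.shape)  # torch.Size([5, 5, 3])

# Flattening Y first pairs i[b] with Y[b] one-to-one.
right = X[i, Y.view(-1)]
print(right.shape)  # torch.Size([5, 3])
```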
1

The answer provided by @Hammad is short and perfect for the job. Here's an alternative solution if you're interested in using some lesser-known PyTorch built-ins. We will use torch.gather (you can achieve the same thing in NumPy with numpy.take_along_axis).

The idea behind torch.gather is to construct a new tensor from two tensors with the same number of dimensions: one containing the indices (here, derived from Y) and one containing the values (here, X). The output has the same shape as the index tensor.

With dim=1 on a 3D input, the operation performed is Z[i][j][k] = X[i][Y[i][j][k]][k].
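
To make this indexing rule concrete, here is a minimal sketch that spells it out as explicit loops and compares the result with `torch.gather`. The small sizes, the extra dimension size `J`, and the names `idx` and `Z_manual` are assumptions made for illustration.

```python
import torch

B, M, N, J = 2, 4, 3, 2
X = torch.arange(B * M * N).reshape(B, M, N)
idx = torch.randint(M, size=(B, J, N))  # int64 indices into dim=1 of X

# Explicit form of Z[i][j][k] = X[i][idx[i][j][k]][k]
Z_manual = torch.empty(B, J, N, dtype=X.dtype)
for i in range(B):
    for j in range(J):
        for k in range(N):
            Z_manual[i, j, k] = X[i, idx[i, j, k], k]

Z = torch.gather(X, dim=1, index=idx)
print(torch.equal(Z, Z_manual))
```

Note that the output takes the shape of the index tensor, `(B, J, N)`, not the shape of `X`.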

Since X's shape is (B, M, N) and Y's shape is (B, 1), we need to repeat Y's values along a new last axis so that its shape becomes (B, 1, N).

This can be achieved with some axis manipulation:

>>> Y.expand(-1, N)[:, None] # expand dim=1 to size N, then insert a new axis at dim=1

The actual call to torch.gather will be:

>>> X.gather(dim=1, index=Y.expand(-1, N)[:, None])

The result has shape (B, 1, N), which you can reduce to (B, N) by appending [:, 0].
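
Putting the pieces together, a minimal end-to-end sketch might look like this. It compares the `gather` result with the direct indexing approach from the other answer; the names `index` and `Z_indexed` are illustrative.

```python
import torch

B, M, N = 5, 7, 3
X = torch.randint(100, size=(B, M, N))
Y = torch.randint(M, size=(B, 1))

index = Y.expand(-1, N)[:, None]        # (B, 1) -> (B, N) -> (B, 1, N)
Z = X.gather(dim=1, index=index)[:, 0]  # (B, 1, N) -> (B, N)

# Same selection using paired index tensors
Z_indexed = X[torch.arange(B), Y.view(-1)]
print(Z.shape, torch.equal(Z, Z_indexed))
```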


This function can be very effective in tricky scenarios...

Ivan