The following code is equivalent to the code in the loop. The difference is that instead of indexing the arrays Z, X and Y sequentially, we index them in parallel using the array i.
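import numpy as np

B, M, N = 13, 7, 19
X = np.random.randint(100, size=[B, M, N])  # source array, shape (B, M, N)
Y = np.random.randint(M, size=[B, 1])       # one index along axis 1 per batch entry
Z = np.random.randint(100, size=[B, N])     # destination array, shape (B, N)
i = np.arange(B)                            # batch indices 0..B-1
Y = Y.ravel()  # flatten Y from shape (B, 1) to (B,) for easy indexing
Z[i] = X[i, Y[i], :]  # row b of Z becomes X[b, Y[b], :]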
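The ravel step above matters more than the comment suggests. A minimal sketch of what happens without it, assuming the same shapes as above (the names wrong and right, and the seeded generator rng, are only for illustration):

```python
import numpy as np

B, M, N = 13, 7, 19
rng = np.random.default_rng(0)          # seeded generator, illustration only
X = rng.integers(100, size=(B, M, N))
Y = rng.integers(M, size=(B, 1))        # column vector, shape (B, 1)
i = np.arange(B)                        # shape (B,)

# Without ravel, i with shape (B,) and Y with shape (B, 1) broadcast
# to (B, B), so every Y entry is paired with every batch index.
wrong = X[i, Y]
print(wrong.shape)   # (13, 13, 19)

# With a rank-1 Y the two index arrays are paired element by element.
right = X[i, Y.ravel()]
print(right.shape)   # (13, 19)

# The intended result sits on the diagonal of the broadcast version.
print(np.array_equal(wrong[i, i], right))   # True
```

So the rank-1 Y is what makes NumPy pair each batch index with exactly one entry of Y.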
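This code can be simplified step by step: a trailing ":" is implicit, Y[i] is the same as Y because i = np.arange(B), and assigning to Z[i] replaces every row of Z: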
>> Z[i] = X[i,Y[i],:]
>> Z[i] = X[i,Y[i]]
>> Z[i] = X[i,Y]
>> Z = X[i,Y]
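As a quick check of the steps above, the sketch below compares each form against an explicit loop. The loop is an assumption about what the original loop version looks like, and the names Z_loop, forms and all_equal are only for illustration:

```python
import numpy as np

B, M, N = 13, 7, 19
rng = np.random.default_rng(0)          # seeded generator, illustration only
X = rng.integers(100, size=(B, M, N))
Y = rng.integers(M, size=B)             # already rank-1, shape (B,)
i = np.arange(B)

# Reference result: assumed sequential version, one batch entry at a time.
Z_loop = np.empty((B, N), dtype=X.dtype)
for b in range(B):
    Z_loop[b] = X[b, Y[b], :]

# The right-hand sides of the simplification steps.
forms = [X[i, Y[i], :], X[i, Y[i]], X[i, Y]]
all_equal = all(np.array_equal(Z_loop, f) for f in forms)
print(all_equal)   # True
```

Note that the last step, Z = X[i,Y], binds Z to a new array instead of writing into the existing one. The values are the same.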
The equivalent PyTorch code:
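import torch

B, M, N = 5, 7, 3
X = torch.randint(100, size=[B, M, N])  # source tensor, shape (B, M, N)
Y = torch.randint(M, size=[B, 1])       # one index along dim 1 per batch entry
Z = torch.randint(100, size=[B, N])     # placeholder; replaced by the last line
i = torch.arange(B)                     # batch indices 0..B-1
Y = Y.ravel()  # flatten Y from shape (B, 1) to (B,)
Z = X[i, Y]    # row b of Z is X[b, Y[b], :]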
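The PyTorch version can be checked the same way. This is a minimal sketch, assuming the sequential version picks X[b, Y[b]] for each batch entry b (the name Z_loop and the fixed seed are only for illustration):

```python
import torch

B, M, N = 5, 7, 3
torch.manual_seed(0)                    # fixed seed, illustration only
X = torch.randint(100, size=[B, M, N])
Y = torch.randint(M, size=[B, 1]).ravel()
i = torch.arange(B)

# Vectorized version: index dims 0 and 1 in parallel.
Z = X[i, Y]

# Assumed sequential version: pick one row of length N per batch entry.
Z_loop = torch.stack([X[b, Y[b]] for b in range(B)])

print(Z.shape)                 # torch.Size([5, 3])
print(torch.equal(Z, Z_loop))  # True
```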