How can I slice a 3D tensor using a 1D tensor? For instance, consider the following two tensors: t of size [Batch, Sequence, Dim], and idx of size [Batch]. The values of idx are restricted to integers between 0 and Sequence-1.
I need idx to select the corresponding slices along the second dimension of t, one slice per batch element. For example:
>>> import torch
>>> t = torch.arange(24).view(2, 3, 4)
>>> t
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
>>> idx = torch.tensor([0, 2])
>>> idx
tensor([0, 2])
The desired output is then:
tensor([[ 0,  1,  2,  3],
        [20, 21, 22, 23]])
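To pin down the intended semantics, here is a minimal reference sketch using an explicit Python loop: for each batch element b, it takes row idx[b] of t[b]. The name expected is illustrative only, and a loop like this is slow for a large Batch; it is meant as a correctness baseline, not a solution.

```python
import torch

t = torch.arange(24).view(2, 3, 4)
idx = torch.tensor([0, 2])

# For each batch element b, select the row t[b, idx[b]] (shape [Dim]),
# then stack the selected rows into a [Batch, Dim] tensor.
expected = torch.stack([t[b, int(idx[b])] for b in range(t.shape[0])])
print(expected)
```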
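The following code solves the problem, but it is inefficient: it builds a one-hot mask of shape [Batch, Sequence, 1], multiplies it with the whole of t, and then sums over the Sequence dimension, even though only one row per batch element is needed.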
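import torch.nn as nn
# One-hot mask of shape [Batch, Sequence, 1]; multiply, then sum over Sequence -> [Batch, Dim]
one_hot_idx = nn.functional.one_hot(idx.long(), num_classes=t.shape[1]).unsqueeze(-1)
(t * one_hot_idx).sum(1)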
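For comparison, here is a minimal sketch of two vectorized alternatives that select the rows directly instead of multiplying by a mask, assuming idx is an int64 tensor on the same device as t. The first pairs each batch index with its idx entry via advanced (integer-array) indexing; the second uses torch.gather with idx expanded to [Batch, 1, Dim]. Names such as batch, gather_idx, out_index, and out_gather are illustrative, and the sketch checks both against the one-hot version from the question.

```python
import torch
import torch.nn.functional as F

t = torch.arange(24).view(2, 3, 4)
idx = torch.tensor([0, 2])

# Advanced indexing: element b of the result is t[batch[b], idx[b]], a [Dim] row.
batch = torch.arange(t.shape[0])
out_index = t[batch, idx]  # shape [Batch, Dim]

# torch.gather along dim 1: out[b, 0, d] = t[b, gather_idx[b, 0, d], d] = t[b, idx[b], d].
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, t.shape[2])  # shape [Batch, 1, Dim]
out_gather = t.gather(1, gather_idx).squeeze(1)  # shape [Batch, Dim]

# The one-hot approach from the question, kept only to confirm the results agree.
one_hot_idx = F.one_hot(idx.long(), num_classes=t.shape[1]).unsqueeze(-1)
out_one_hot = (t * one_hot_idx).sum(1)

assert torch.equal(out_index, out_gather)
assert torch.equal(out_index, out_one_hot)
print(out_index)
```

Both alternatives avoid creating the [Batch, Sequence, Dim] product tensor. The indexing form is usually the more readable of the two; gather can be convenient when the index tensor is already shaped to match the output.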