
How can I slice a 3D tensor using a 1D tensor? For instance, consider the following 2 tensors: t of size [Batch, Sequence, Dim]; and idx of size [Batch]. The values of idx are restricted to be integers between 0 and Sequence-1.

I need tensor idx to select the corresponding slices in the second dimension of tensor t. For example:

>>> t = torch.arange(24).view(2, 3, 4)
>>> t
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

>>> idx = torch.tensor([0, 2])
>>> idx
tensor([0, 2])

Then the desired output is: tensor([[ 0, 1, 2, 3], [20, 21, 22, 23]]).

The following code solves the problem, but it's inefficient: it builds a one-hot mask, multiplies it with t (materializing a full [Batch, Sequence, Dim] product), and then sums over the sequence dimension.

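import torch.nn as nn
# One-hot mask of shape [Batch, Sequence, 1], broadcast against t and summed over Sequence
one_hot_idx = nn.functional.one_hot(idx.long(), num_classes=t.shape[1]).unsqueeze(-1)
(t * one_hot_idx).sum(1)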
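For reference, the desired result is out[b] = t[b, idx[b]] for every batch element b. The sketch below writes that definition as an explicit Python loop (the helper name select_loop is hypothetical, not from the question) and checks that the one-hot approach above agrees with it. It is only meant as a correctness baseline for faster versions, not as a recommended implementation.

```python
import torch
import torch.nn.functional as F


def select_loop(t: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper: for each batch element b,
    # take row idx[b] of t[b] and stack the rows into [Batch, Dim].
    return torch.stack([t[b, idx[b]] for b in range(t.size(0))])


t = torch.arange(24).view(2, 3, 4)  # [Batch, Sequence, Dim]
idx = torch.tensor([0, 2])          # [Batch]

# The question's one-hot approach, for comparison.
one_hot_idx = F.one_hot(idx.long(), num_classes=t.shape[1]).unsqueeze(-1)
one_hot_result = (t * one_hot_idx).sum(1)

print(torch.equal(select_loop(t, idx), one_hot_result))  # True
```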
core_not_dumped
Oren

1 Answer


You can do it with integer-array (advanced) indexing, pairing each batch index with the corresponding entry of idx:

import torch
t = torch.arange(24).view(2, 3, 4)
idx = torch.tensor([0, 2])
print(t[range(len(idx)), idx])

Output:

tensor([[ 0,  1,  2,  3],
        [20, 21, 22, 23]])
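To make the mechanics explicit: when two integer index tensors are used together, PyTorch pairs them elementwise, so t[batch_idx, idx] selects t[batch_idx[b], idx[b], :] for each b and leaves the last dimension intact. A small sketch, assuming idx has exactly Batch entries, building the batch indices with torch.arange instead of range and checking that both forms agree:

```python
import torch

t = torch.arange(24).view(2, 3, 4)  # [Batch, Sequence, Dim]
idx = torch.tensor([0, 2])          # [Batch]

# batch_idx and idx are paired elementwise:
# out[b] = t[batch_idx[b], idx[b], :]
batch_idx = torch.arange(t.size(0))
out = t[batch_idx, idx]

print(out.shape)                                   # torch.Size([2, 4])
print(torch.equal(out, t[range(len(idx)), idx]))   # True
```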
Naphat Amundsen
  • Thanks, it does work, but I'm wondering if there's a solution that doesn't require a loop. – Oren Dec 22 '22 at 12:47
  • Are you thinking about the range? You could replace it with torch.arange, but I don't believe it will scale any better for large tensors, as it will need more memory allocation. – Naphat Amundsen Dec 22 '22 at 12:51
  • `torch.gather` does the trick in the [2D case](https://discuss.pytorch.org/t/select-specific-columns-of-each-row-in-a-torch-tensor/497/2). I was hoping it could be extended for 3D :| – Oren Dec 22 '22 at 20:45
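torch.gather can in fact be extended to the 3D case; this is not taken from the thread, just a sketch of one way to do it. gather requires an index tensor with the same number of dimensions as the input, so idx is reshaped to [Batch, 1, 1] and expanded to [Batch, 1, Dim] so that every feature of the chosen sequence position is gathered. It assumes idx is an int64 tensor (call idx.long() otherwise); whether it is faster than advanced indexing has not been measured here.

```python
import torch

t = torch.arange(24).view(2, 3, 4)  # [Batch, Sequence, Dim]
idx = torch.tensor([0, 2])          # [Batch], int64

B, S, D = t.shape
# Index must have t's rank: [B] -> [B, 1, 1] -> [B, 1, D] (expand makes a view, no copy)
gather_idx = idx.view(B, 1, 1).expand(B, 1, D)
# out[b, 0, d] = t[b, idx[b], d]; squeeze the singleton Sequence dim -> [B, D]
out = t.gather(1, gather_idx).squeeze(1)

print(out)
# tensor([[ 0,  1,  2,  3],
#         [20, 21, 22, 23]])
```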