
I'm working on a many-to-one RNN with variable-length sequences, and I'm trying to extract the output at the last non-padded time step of each sequence. Say the batch size is 3, there are 16 hidden units, and the maximum length in this batch is 6, so the "output" has shape (3, 6, 16). For example:

tensor([[[-0.0650, -0.0712, -0.1024, -0.0544,  0.1491, -0.0390, -0.0559,
           0.0675,  0.0298, -0.1015,  0.1557, -0.0176, -0.0348,  0.0585,
          -0.0673, -0.1097],
         [-0.0515, -0.1113, -0.1194, -0.0942,  0.2349, -0.0736, -0.0474,
           0.1339,  0.0362, -0.1515,  0.2344,  0.0125, -0.0403,  0.0810,
          -0.0817, -0.1919],
         [-0.0588, -0.1078, -0.1033, -0.0454,  0.2273, -0.0773, -0.0720,
           0.1857,  0.0817, -0.1805,  0.2329,  0.0146, -0.0437,  0.0905,
          -0.1100, -0.2268],
         [-0.0926, -0.1068, -0.1196, -0.0647,  0.2364, -0.1125, -0.0422,
           0.1958,  0.0841, -0.2176,  0.2558,  0.0397, -0.0856,  0.0867,
          -0.0862, -0.2232],
         [-0.0692, -0.1501, -0.1342, -0.0725,  0.2564, -0.1084, -0.0767,
           0.2042,  0.1136, -0.2037,  0.2773, -0.0236, -0.0786,  0.0889,
          -0.1053, -0.2285],
         [-0.0783, -0.1734, -0.1472, -0.1053,  0.2649, -0.0928, -0.0306,
           0.1727,  0.0962, -0.2102,  0.3104, -0.0102, -0.0566,  0.0878,
          -0.1159, -0.2546]],

        [[-0.0446, -0.0554, -0.1074, -0.0609,  0.1355, -0.0883, -0.0113,
           0.0571, -0.0149, -0.0931,  0.1628,  0.0269, -0.0660,  0.0590,
          -0.0089, -0.1344],
         [-0.0122, -0.1504, -0.1654, -0.0200,  0.2308, -0.0633,  0.0030,
           0.1265,  0.0164, -0.1145,  0.1606,  0.0188, -0.0546,  0.0469,
          -0.0717, -0.2503],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000]],

        [[-0.0213, -0.0946, -0.0691, -0.0673,  0.1766, -0.1065, -0.0824,
           0.0633, -0.0019, -0.1095,  0.2158, -0.0181, -0.0451,  0.0849,
          -0.0387, -0.1313],
         [ 0.0065, -0.1240, -0.1519, -0.0734,  0.2415, -0.1553, -0.0661,
           0.1188,  0.0530, -0.1496,  0.2765, -0.0181, -0.0736,  0.1144,
          -0.0671, -0.2182],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000]]], device='cuda:0', grad_fn=<TransposeBackward0>)

I want to extract this (the last non-padded row in each item):

tensor([[-0.0783, -0.1734, -0.1472, -0.1053,  0.2649, -0.0928, -0.0306,  0.1727,
          0.0962, -0.2102,  0.3104, -0.0102, -0.0566,  0.0878, -0.1159, -0.2546],
        [-0.0122, -0.1504, -0.1654, -0.0200,  0.2308, -0.0633,  0.0030,  0.1265,
          0.0164, -0.1145,  0.1606,  0.0188, -0.0546,  0.0469, -0.0717, -0.2503],
        [ 0.0065, -0.1240, -0.1519, -0.0734,  0.2415, -0.1553, -0.0661,  0.1188,
          0.0530, -0.1496,  0.2765, -0.0181, -0.0736,  0.1144, -0.0671, -0.2182]])

with proper gradient info so the RNN can continue training.

What I've tried:

selectors = torch.tensor([5, 1, 1])
torch.index_select(sequence, 1, selectors)

but this returns three rows (indices 5, 1, and 1, 0-indexed) for every item, whereas I want row 5 from the first item, row 1 from the second item, and row 1 from the third.
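To make the behaviour concrete, here is a small sketch of what `torch.index_select` does with these indices. The tensor built with `torch.arange` is only a stand-in for the real RNN output, and the names `picked` and `diagonal` are made up for illustration.

```python
import torch

# Stand-in for the RNN output: 3 items, 6 time steps, 16 hidden units.
sequence = torch.arange(3 * 6 * 16, dtype=torch.float32).reshape(3, 6, 16)
selectors = torch.tensor([5, 1, 1])

# index_select applies the SAME list of indices to every item along dim 1,
# so picked[i, j] == sequence[i, selectors[j]] and the shape is (3, 3, 16).
picked = torch.index_select(sequence, 1, selectors)

# The rows that are actually wanted sit on the "diagonal" picked[i, i].
diagonal = picked[torch.arange(3), torch.arange(3)]  # shape (3, 16)
```

So the call is not wrong, it just selects more than is needed: every item gets all three time steps instead of only its own.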

I've also tried

arr = torch.zeros(len(sequence), 16, dtype=torch.float32)
for i in range(len(sequence)):
    arr[i, :] = sequence[i, selectors[i], :]

This gives me something very close to what I want, but I suspect there are some issues with the gradient info.
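If a loop is acceptable, one variant that sidesteps the preallocated buffer is to collect the rows and combine them with `torch.stack`. The sketch below uses a random CPU tensor as a stand-in for the RNN output and a dummy `sum()` loss, both of which are assumptions made only to check where the gradient ends up.

```python
import torch

# Stand-in for the RNN output; requires_grad lets us inspect gradients.
sequence = torch.randn(3, 6, 16, requires_grad=True)
selectors = torch.tensor([5, 1, 1])

# Pick one row per item, then stack the rows into a (3, 16) tensor.
rows = [sequence[i, selectors[i]] for i in range(len(sequence))]
arr = torch.stack(rows)

# Dummy loss: every selected element contributes a gradient of 1.
arr.sum().backward()

# Gradient is 1 at the selected time steps and 0 everywhere else.
print(sequence.grad[:, :, 0])
```

The result stays on the same device as `sequence`, and autograd records the indexing and stacking operations like any other op.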

Michael Xu

2 Answers


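I found my answer in another question:

Get each sequence's last item from packed sequence

Adapted to my problem, the idea is to index the batch dimension with torch.arange and the time dimension with selectors, so that each item is paired with its own last step: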

sequences[torch.arange(len(sequences)), selectors]
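As a sanity check on the gradient question, here is a minimal self-contained sketch of the same indexing. The random CPU tensor, the dummy `sum()` loss, and the names `batch_idx`, `last`, and `grad_mask` are assumptions for illustration and not part of the original model.

```python
import torch

torch.manual_seed(0)

# Stand-in for the RNN output: batch of 3, max length 6, 16 hidden units.
sequences = torch.randn(3, 6, 16, requires_grad=True)
# Index of the last non-padded step in each item (0-indexed).
selectors = torch.tensor([5, 1, 1])

# Pair item i with time step selectors[i]: one row per item.
batch_idx = torch.arange(sequences.size(0))
last = sequences[batch_idx, selectors]  # shape (3, 16)

# Backpropagate a dummy loss and record which time steps received gradient.
last.sum().backward()
grad_mask = sequences.grad.abs().sum(dim=2) > 0  # shape (3, 6)
print(grad_mask)
```

Only the selected step of each item should be marked in `grad_mask`, which suggests the padded positions are left out of the backward pass as intended.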
Michael Xu

You can use torch.gather:

index = torch.tensor([5, 1, 1], device=sequences.device)[:, None, None].expand(3, 1, 16)
sequences.gather(dim=1, index=index).squeeze(1)
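The shapes involved in the `gather` call can be easy to lose track of, so here is a hedged step-by-step sketch on a random CPU tensor. The names `index`, `gathered`, `last`, and `same` are illustrative, and the comparison with `torch.arange` indexing is only there to show that both approaches pick the same rows.

```python
import torch

# Stand-in for the RNN output: 3 items, 6 time steps, 16 hidden units.
sequences = torch.randn(3, 6, 16, requires_grad=True)
selectors = torch.tensor([5, 1, 1])

# gather needs an index with as many dimensions as the input:
# (3,) -> (3, 1, 1) -> (3, 1, 16), one copy of the time index per hidden unit.
index = selectors[:, None, None].expand(-1, 1, sequences.size(2))

gathered = sequences.gather(dim=1, index=index)  # shape (3, 1, 16)
last = gathered.squeeze(1)                       # shape (3, 16)

# Both approaches should select exactly the same rows.
same = torch.equal(last, sequences[torch.arange(3), selectors])
print(last.shape, same)
```

Without the final `squeeze(1)` the result keeps a singleton time dimension, which may or may not be what the next layer expects.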
Shai