I'm working on a many-to-one RNN with variable-length sequences, and I'm trying to extract the output of the last non-padded time step for each item. In particular, say I'm using a batch size of 3 with 16 hidden units, and the max length in this batch is 6, so the "output" has shape (3, 6, 16), for example:
tensor([[[-0.0650, -0.0712, -0.1024, -0.0544, 0.1491, -0.0390, -0.0559,
0.0675, 0.0298, -0.1015, 0.1557, -0.0176, -0.0348, 0.0585,
-0.0673, -0.1097],
[-0.0515, -0.1113, -0.1194, -0.0942, 0.2349, -0.0736, -0.0474,
0.1339, 0.0362, -0.1515, 0.2344, 0.0125, -0.0403, 0.0810,
-0.0817, -0.1919],
[-0.0588, -0.1078, -0.1033, -0.0454, 0.2273, -0.0773, -0.0720,
0.1857, 0.0817, -0.1805, 0.2329, 0.0146, -0.0437, 0.0905,
-0.1100, -0.2268],
[-0.0926, -0.1068, -0.1196, -0.0647, 0.2364, -0.1125, -0.0422,
0.1958, 0.0841, -0.2176, 0.2558, 0.0397, -0.0856, 0.0867,
-0.0862, -0.2232],
[-0.0692, -0.1501, -0.1342, -0.0725, 0.2564, -0.1084, -0.0767,
0.2042, 0.1136, -0.2037, 0.2773, -0.0236, -0.0786, 0.0889,
-0.1053, -0.2285],
[-0.0783, -0.1734, -0.1472, -0.1053, 0.2649, -0.0928, -0.0306,
0.1727, 0.0962, -0.2102, 0.3104, -0.0102, -0.0566, 0.0878,
-0.1159, -0.2546]],
[[-0.0446, -0.0554, -0.1074, -0.0609, 0.1355, -0.0883, -0.0113,
0.0571, -0.0149, -0.0931, 0.1628, 0.0269, -0.0660, 0.0590,
-0.0089, -0.1344],
[-0.0122, -0.1504, -0.1654, -0.0200, 0.2308, -0.0633, 0.0030,
0.1265, 0.0164, -0.1145, 0.1606, 0.0188, -0.0546, 0.0469,
-0.0717, -0.2503],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]],
[[-0.0213, -0.0946, -0.0691, -0.0673, 0.1766, -0.1065, -0.0824,
0.0633, -0.0019, -0.1095, 0.2158, -0.0181, -0.0451, 0.0849,
-0.0387, -0.1313],
[ 0.0065, -0.1240, -0.1519, -0.0734, 0.2415, -0.1553, -0.0661,
0.1188, 0.0530, -0.1496, 0.2765, -0.0181, -0.0736, 0.1144,
-0.0671, -0.2182],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000]]], device='cuda:0', grad_fn=<TransposeBackward0>)
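For reference, here is a simplified version of the kind of setup that produces a padded output like this (the GRU layer, the input size of 8, and the lengths [6, 2, 2] are placeholders for illustration, not my exact model):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)  # placeholder layer
x = torch.randn(3, 6, 8)                  # padded batch: 3 items, max length 6
lengths = torch.tensor([6, 2, 2])         # true lengths, sorted descending
packed_out, _ = rnn(pack_padded_sequence(x, lengths, batch_first=True))
output, _ = pad_packed_sequence(packed_out, batch_first=True)
print(output.shape)                       # torch.Size([3, 6, 16]); padded steps are zeros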
I want to extract this (the last non-padded row for each item):
tensor([[-0.0783, -0.1734, -0.1472, -0.1053, 0.2649, -0.0928, -0.0306, 0.1727,
0.0962, -0.2102, 0.3104, -0.0102, -0.0566, 0.0878, -0.1159, -0.2546],
[-0.0122, -0.1504, -0.1654, -0.0200, 0.2308, -0.0633, 0.0030, 0.1265,
0.0164, -0.1145, 0.1606, 0.0188, -0.0546, 0.0469, -0.0717, -0.2503],
[ 0.0065, -0.1240, -0.1519, -0.0734, 0.2415, -0.1553, -0.0661, 0.1188,
0.0530, -0.1496, 0.2765, -0.0181, -0.0736, 0.1144, -0.0671, -0.2182]])
with proper gradient info so the RNN can continue training.
What I've tried:
selectors = torch.tensor([5, 1, 1])
torch.index_select(sequence, 1, selectors)
but this returns the same 3 rows (the 5th, 1st, and 1st, 0-indexed) for every item, whereas I want the 5th row for the first item, the 1st row for the second item, and so on.
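In other words, what I'm after is one time-step index per batch item, broadcast across the hidden dimension. Something like the following with torch.gather is my best guess at expressing that, though I'm not sure it is the right approach or whether it keeps the gradient graph intact:

# index of the last non-padded step per item, expanded to (batch, 1, hidden)
idx = selectors.to(sequence.device).view(-1, 1, 1).expand(-1, 1, sequence.size(2))
# pick one full row per item along the time dimension
last = torch.gather(sequence, 1, idx).squeeze(1)   # shape (3, 16)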
I've also tried
arr = torch.zeros(len(sequence), 16, dtype=torch.float32)
for i in range(len(sequence)):
    arr[i, :] = sequence[i, selectors[i], :]
This gives me something very close to what I want, but I'm worried it doesn't preserve the gradient information needed for backpropagation.
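As a sanity check on that concern, this is how I've been probing whether autograd still tracks the per-row copies (I'm not sure this check is conclusive):

print(arr.grad_fn)    # None here would mean the row copies are not tracked by autograd
arr.sum().backward()  # dummy loss; raises an error if arr is detached from the graph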