2

I am trying to put a packed and padded sequence through a GRU and retrieve the output of the last item of each sequence. Of course I don't mean the -1 item, but the actual last, not-padded item. We know the lengths of the sequences in advance, so it should be as easy as extracting, for each sequence, the item at index length-1.

I tried the following

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Data
input = torch.Tensor([[[0., 0., 0.],
                       [1., 0., 1.],
                       [1., 1., 0.],
                       [1., 0., 1.],
                       [1., 0., 1.],
                       [1., 1., 0.]],

                      [[1., 1., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[0., 0., 0.],
                       [1., 0., 0.],
                       [1., 1., 1.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]],

                      [[1., 1., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.],
                       [0., 0., 0.]]])

lengths = [6, 4, 3, 1]
p = pack_padded_sequence(input, lengths, batch_first=True)

# Forward
gru = torch.nn.GRU(3, 12, batch_first=True)
packed_output, gru_h = gru(p)

# Unpack
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

last_seq_idxs = torch.LongTensor([x-1 for x in input_sizes])

last_seq_items = torch.index_select(output, 1, last_seq_idxs) 

print(last_seq_items.size())
# torch.Size([4, 4, 12])

But the shape is not what I expect. I had expected to get 4x12, i.e. the last item of each individual sequence x hidden size.

I could loop through the whole thing and build a new tensor containing the items I need, but I was hoping for a built-in approach that took advantage of some smart math. I fear that manually looping and building will result in very poor performance.

Bram Vanroy

2 Answers

4

Instead of the last two operations, last_seq_idxs and last_seq_items, you can just do last_seq_items = output[torch.arange(4), input_sizes - 1].

I don't think index_select is doing the right thing. It selects the given indices along dimension 1 for every sequence in the batch, so you get all four indices for all four sequences, and therefore your output size is [4, 4, 12].
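A minimal sketch of the advanced-indexing trick on a toy tensor (the shapes and values here are made up for illustration, not taken from the question): pairing a batch-index tensor with a per-sequence position tensor selects one vector per sequence, which is exactly the 4x12-style result the question asks for.

```python
import torch

# Toy "padded output": batch of 3 sequences, max length 4, hidden size 2
output = torch.arange(24, dtype=torch.float32).view(3, 4, 2)
lengths = torch.tensor([4, 2, 3])

# Advanced indexing: for each batch row b, pick output[b, lengths[b] - 1].
# Both index tensors have shape (3,), so the result has shape (3, 2).
last = output[torch.arange(output.size(0)), lengths - 1]

print(last.shape)  # torch.Size([3, 2])
```

Compare with torch.index_select(output, 1, lengths - 1), which would instead return every listed position for every sequence, shape (3, 3, 2).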

Umang Gupta
1

A more verbose alternative to Umang Gupta's answer:

# ...
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)
# One entry per sequence: its last actual (non-padded) step, unsqueezed
last_seq = [output[e, i-1, :].unsqueeze(0) for e, i in enumerate(input_sizes)]
# Concatenate all sequences back together to get the batch
last_seq = torch.cat(last_seq, dim=0)
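To check that the loop-based approach and the vectorised indexing from the other answer agree, here is a small end-to-end sketch using the same GRU setup as the question, but with random input for brevity:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
gru = torch.nn.GRU(3, 12, batch_first=True)

# Random padded batch: 4 sequences, max length 6, feature size 3
input = torch.randn(4, 6, 3)
lengths = [6, 4, 3, 1]

p = pack_padded_sequence(input, lengths, batch_first=True)
packed_output, gru_h = gru(p)
output, input_sizes = pad_packed_sequence(packed_output, batch_first=True)

# Loop-based extraction (this answer)
last_loop = torch.cat(
    [output[e, i - 1, :].unsqueeze(0) for e, i in enumerate(input_sizes)],
    dim=0)

# Vectorised extraction (other answer); input_sizes is already a tensor
last_vec = output[torch.arange(output.size(0)), input_sizes - 1]

print(torch.equal(last_loop, last_vec))  # True, shape (4, 12) either way
```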
Bram Vanroy
  • I think this would be slower, isn't it? – Umang Gupta Mar 28 '19 at 19:08
  • 1
    @UmangGupta Oh yes, I definitely think so as yours is just a slice whereas my approach requires an iteration and a concatenation. I posted mine for illustrative purposes. But yours is the one that should be used. – Bram Vanroy Mar 28 '19 at 20:15