
Hi, I have a question about how to collect the correct result from a Bi-LSTM module's output.

Suppose I have a length-10 sequence feeding into a single-layer LSTM with 100 hidden units:

lstm = nn.LSTM(5, 100, 1, bidirectional=True)

output will be of shape:

[10 (seq_length), 1 (batch), 200 (num_directions * hidden_size)]
# or, according to the docs, can be viewed as
[10 (seq_length), 1 (batch), 2 (num_directions), 100 (hidden_size)]
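
If I understand the docs correctly, that view would be obtained with something like this (a sketch; x is my 10x1x5 input):

output, (h_n, c_n) = lstm(x)         # x of shape [10, 1, 5]
output = output.view(10, 1, 2, 100)  # [seq_len, batch, num_directions, hidden_size]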

If I want to get the 3rd (1-indexed) input's output in both directions (two 100-dim vectors), how can I do it correctly?

I know output[2, 0] will give me a 200-dim vector. Does this 200-dim vector represent the output of the 3rd input in both directions?

One thing bothering me: during the reverse pass, the 3rd (1-indexed) output vector is calculated from the 8th (1-indexed) input, right?

Will PyTorch automatically take care of this and group the output by direction?

Thanks!


3 Answers


Yes, when using a BiLSTM the hidden states of the two directions are just concatenated (the second half is the hidden state from feeding in the reversed sequence).
So splitting it down the middle works just fine.

Since reshaping works from the rightmost dimension to the left, you won't have any problems separating the two directions.


Here is a small example:

import torch

# so these are your original hidden states for each direction
# in this case hidden size is 5, but this works for any size
direction_one_out = torch.tensor(range(5))
direction_two_out = torch.tensor(list(reversed(range(5))))
print('Direction one:')
print(direction_one_out)
print('Direction two:')
print(direction_two_out)

# before being output they will be concatenated
# I'm adding batch and sequence-length dimensions here; in this case seq length is 1
hidden = torch.cat((direction_one_out, direction_two_out), dim=0).view(1, 1, -1)
print('\nYour hidden output:')
print(hidden, hidden.shape)

# trivial case, reshaping for one hidden state
hidden_reshaped = hidden.view(1, 1, 2, -1)
print('\nReshaped:')
print(hidden_reshaped, hidden_reshaped.shape)

# This works for arbitrary sequence lengths as well, as you can see here
# I've set the sequence length to 5, but this will work for any other value too
print('\nThis also works for multiple hidden states in a tensor:')
multi_hidden = hidden.expand(5, 1, 10)
print(multi_hidden, multi_hidden.shape)
print('Directions can be split up just like this:')
multi_hidden = multi_hidden.view(5, 1, 2, 5)
print(multi_hidden, multi_hidden.shape)

Output:

Direction one:
tensor([0, 1, 2, 3, 4])
Direction two:
tensor([4, 3, 2, 1, 0])

Your hidden output:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([1, 1, 10])

Reshaped:
tensor([[[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]]]) torch.Size([1, 1, 2, 5])

This also works for multiple hidden states in a tensor:
tensor([[[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]],

        [[0, 1, 2, 3, 4, 4, 3, 2, 1, 0]]]) torch.Size([5, 1, 10])
Directions can be split up just like this:
tensor([[[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]],


        [[[0, 1, 2, 3, 4],
          [4, 3, 2, 1, 0]]]]) torch.Size([5, 1, 2, 5])

Hope this helps! :)


I know output[2, 0] will give me a 200-dim vector. Does this 200-dim vector represent the output of the 3rd input in both directions?

The answer is YES.

The output tensor of the LSTM module is the concatenation of the forward LSTM output and the backward LSTM output at the corresponding position in the input sequence. The h_n tensor is the output at the last timestep, which is the output of the last token in the forward LSTM but of the first token in the backward LSTM.

In [1]: import torch
   ...: lstm = torch.nn.LSTM(input_size=5, hidden_size=3, bidirectional=True)
   ...: seq_len, batch, input_size, num_directions = 3, 1, 5, 2
   ...: in_data = torch.randint(10, (seq_len, batch, input_size)).float()
   ...: output, (h_n, c_n) = lstm(in_data)
   ...: 

In [2]: # output of shape (seq_len, batch, num_directions * hidden_size)
   ...: 
   ...: print(output)
   ...: 
tensor([[[ 0.0379,  0.0169,  0.2539,  0.2547,  0.0456, -0.1274]],

        [[ 0.7753,  0.0862, -0.0001,  0.3897,  0.0688, -0.0002]],

        [[ 0.7120,  0.2965, -0.3405,  0.0946,  0.0360, -0.0519]]],
       grad_fn=<CatBackward>)

In [3]: # h_n of shape (num_layers * num_directions, batch, hidden_size)
   ...: 
   ...: print(h_n)
   ...: 
tensor([[[ 0.7120,  0.2965, -0.3405]],

        [[ 0.2547,  0.0456, -0.1274]]], grad_fn=<ViewBackward>)

In [4]: output = output.view(seq_len, batch, num_directions, lstm.hidden_size)
   ...: print(output[-1, 0, 0])  # forward LSTM output of last token
   ...: print(output[0, 0, 1])  # backward LSTM output of first token
   ...: 
tensor([ 0.7120,  0.2965, -0.3405], grad_fn=<SelectBackward>)
tensor([ 0.2547,  0.0456, -0.1274], grad_fn=<SelectBackward>)

In [5]: h_n = h_n.view(lstm.num_layers, num_directions, batch, lstm.hidden_size)
   ...: print(h_n[0, 0, 0])  # h_n of forward LSTM
   ...: print(h_n[0, 1, 0])  # h_n of backward LSTM
   ...: 
tensor([ 0.7120,  0.2965, -0.3405], grad_fn=<SelectBackward>)
tensor([ 0.2547,  0.0456, -0.1274], grad_fn=<SelectBackward>)
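
Applying the same idea to the question's setup (a sketch with the question's dimensions; x is a hypothetical random input):

lstm = torch.nn.LSTM(input_size=5, hidden_size=100, bidirectional=True)
x = torch.randn(10, 1, 5)         # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm(x)      # output: (10, 1, 200)
out = output.view(10, 1, 2, 100)  # (seq_len, batch, num_directions, hidden_size)
forward_third = out[2, 0, 0]      # forward output for the 3rd token
backward_third = out[2, 0, 1]     # backward output for the 3rd token (computed after reading tokens 10 through 3)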
  • Hey dd., I have been trying to figure out your answer for a week and I still can't; a bit of help? Your answer suggests to me that `output[-1, 0, :]` does not give the desired output, and that one should instead take the two halves as you did (command 4). Can you detail this, or confirm? My question really is: how do I get h_n, the last output of the RNN? – Marine Galantin May 23 '21 at 21:48
  • In particular, this seems to contradict the OP's question, to which you said yes? https://towardsdatascience.com/understanding-bidirectional-rnn-in-pytorch-5bd25a5dd66 – Marine Galantin May 23 '21 at 21:54

I know output[2, 0] will give me a 200-dim vector. Does this 200-dim vector represent the output of the 3rd input in both directions?

Yes and no. It does represent the output of the input at index 2 in both directions, but not the 3rd input in each direction. The 3rd input in the forward direction, as the RNN sees it, is at index 2, while the 3rd input in the reverse direction is at index 7.

As far as the last possible output in each direction is concerned: in the forward direction, the last output is at index 9 (the 10th output), whereas in the reverse direction the last output is at index 0 (also the 10th output in that direction).

If you're viewing output as:

[10 (seq_length), 1 (batch),  2 (num_directions), 100 (hidden_size)]

Then the last output in the forward direction will be output[9][0][0] and the last output in the reverse direction will be output[0][0][1].
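
For example, a minimal sketch with the question's dimensions (x is a hypothetical random input; the variable names are illustrative):

import torch

lstm = torch.nn.LSTM(5, 100, 1, bidirectional=True)
x = torch.randn(10, 1, 5)            # [seq_len, batch, input_size]
output, _ = lstm(x)                  # [10, 1, 200]
output = output.view(10, 1, 2, 100)  # [seq_len, batch, num_directions, hidden_size]
last_forward = output[9][0][0]       # last output in the forward direction
last_reverse = output[0][0][1]       # last output in the reverse direction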

I hope this clarifies things.
