
I'm not sure how to select the last hidden and cell states of a bidirectional LSTM in PyTorch.

```python
output, (hn, cn) = bi_lstm(input, (h0, c0))
```
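Here is a minimal runnable setup around that call. It assumes a two-layer `nn.LSTM` with `bidirectional=True` and the default `batch_first=False`, and the sizes are arbitrary values chosen only for illustration. It prints the shapes of the three returned tensors, which is where the forward and backward states are stored:

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2

bi_lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)

input = torch.randn(seq_len, batch, input_size)
# Initial states need one slice per (layer, direction) pair: num_layers * 2.
h0 = torch.zeros(num_layers * 2, batch, hidden_size)
c0 = torch.zeros(num_layers * 2, batch, hidden_size)

output, (hn, cn) = bi_lstm(input, (h0, c0))

print(output.shape)  # (seq_len, batch, 2 * hidden_size)
print(hn.shape)      # (num_layers * 2, batch, hidden_size)
print(cn.shape)      # (num_layers * 2, batch, hidden_size)
```

The last dimension of `output` is `2 * hidden_size` because the forward and backward outputs are concatenated. The first dimension of `hn` and `cn` is `num_layers * 2` because there is one slice per layer and direction.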

How can I use `output`, `hn`, and `cn` to extract the last forward and backward hidden states?

For the backward direction, I want the hidden state obtained after processing the entire sequence in reverse.
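Below is a sketch of one way to extract the final states. It keeps the same assumptions: hypothetical sizes, `batch_first=False`, and unpacked input where all sequences have equal length. Following the layout described in the `nn.LSTM` documentation, `hn` and `cn` can be viewed as `(num_layers, 2, batch, hidden_size)`. For the last layer, direction 0 is forward and direction 1 is backward. The backward hidden state after consuming the whole sequence in reverse also appears in `output`, but at time step 0 rather than at the last step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
bi_lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)

input = torch.randn(seq_len, batch, input_size)
output, (hn, cn) = bi_lstm(input)  # h0 and c0 default to zeros when omitted

# Split the first dimension (num_layers * 2) into separate layer and direction axes.
hn_layers = hn.view(num_layers, 2, batch, hidden_size)
cn_layers = cn.view(num_layers, 2, batch, hidden_size)

# Last layer: direction 0 is forward, direction 1 is backward.
h_fwd, h_bwd = hn_layers[-1, 0], hn_layers[-1, 1]
c_fwd, c_bwd = cn_layers[-1, 0], cn_layers[-1, 1]

# The same hidden states read from `output` (seq_len, batch, 2 * hidden_size):
# the forward half at the last time step, and the backward half at the first
# time step, because the backward pass finishes at t = 0.
out_fwd = output[-1, :, :hidden_size]
out_bwd = output[0, :, hidden_size:]

print(torch.allclose(h_fwd, out_fwd), torch.allclose(h_bwd, out_bwd))
```

Cell states are only available through `cn`, because `output` holds hidden states only. If the batch contains padded sequences of different lengths, `output[-1]` may correspond to padding. Packing the input with `torch.nn.utils.rnn.pack_padded_sequence` makes `hn` and `cn` hold the states at each sequence's true end.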

miditower
