I'm not sure how to select the last hidden and cell states of a bidirectional LSTM in PyTorch. I call it like this:
```python
output, (hn, cn) = bi_lstm(input, (h0, c0))
```
How can I use `output`, `hn`, and `cn` to extract the last forward and backward hidden states?
For the backward direction, I want the hidden state obtained after the LSTM has processed the entire sequence backwards.
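A minimal sketch of how the three tensors are laid out, assuming the default `batch_first=False`, equal-length (unpacked) sequences, and arbitrary illustrative sizes; `x` stands in for `input` to avoid shadowing the Python builtin. It relies on the `nn.LSTM` convention that `hn` and `cn` can be viewed as `(num_layers, num_directions, batch, hidden_size)`, with direction 0 being forward and direction 1 backward.

```python
import torch
import torch.nn as nn

# Illustrative sizes only (not from the question).
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
num_directions = 2  # bidirectional=True

torch.manual_seed(0)
bi_lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)

# Default layout (batch_first=False): (seq_len, batch, input_size).
x = torch.randn(seq_len, batch, input_size)
h0 = torch.zeros(num_layers * num_directions, batch, hidden_size)
c0 = torch.zeros(num_layers * num_directions, batch, hidden_size)

output, (hn, cn) = bi_lstm(x, (h0, c0))
# output: (seq_len, batch, num_directions * hidden_size), top layer only
# hn, cn: (num_layers * num_directions, batch, hidden_size), every layer

# Separate layers and directions; direction 0 is forward, 1 is backward.
hn_view = hn.view(num_layers, num_directions, batch, hidden_size)
cn_view = cn.view(num_layers, num_directions, batch, hidden_size)

last_fwd_h = hn_view[-1, 0]  # top-layer forward state after reading x[0] .. x[-1]
last_bwd_h = hn_view[-1, 1]  # top-layer backward state after reading x[-1] .. x[0]
last_fwd_c = cn_view[-1, 0]  # matching cell states exist only in cn
last_bwd_c = cn_view[-1, 1]

# The same hidden states can be read from output: the forward half at the
# LAST time step, and the backward half at the FIRST time step.
fwd_from_output = output[-1, :, :hidden_size]
bwd_from_output = output[0, :, hidden_size:]

print(torch.allclose(last_fwd_h, fwd_from_output))  # True
print(torch.allclose(last_bwd_h, bwd_from_output))  # True
```

Equivalently, `hn[-2]` and `hn[-1]` are the top layer's forward and backward states. With `batch_first=True` only `output` changes (`output[:, -1, :hidden_size]` and `output[:, 0, hidden_size:]`); `hn` and `cn` keep the same layout. If the sequences are padded and packed, reading `output[-1]` would pick up padding for the shorter sequences, so `hn` is the safer source in that case.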