I am using two nn.LSTM(bidirectional=False) modules in PyTorch to reproduce the behavior of a single nn.LSTM(bidirectional=True). The input arrives as a streaming sequence that I process piece by piece, and the streamed result does not match the bidirectional reference. Below are my implementation and a test case; is there anything wrong with my implementation?
import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMModel, self).__init__()
        # Reference bidirectional LSTM, plus two unidirectional LSTMs that
        # should reproduce it once they share its weights.
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
        self.forward_lstm = nn.LSTM(input_size, hidden_size, bidirectional=False)
        self.backward_lstm = nn.LSTM(input_size, hidden_size, bidirectional=False)

    def bidirectional_lstm(self, input_data):
        output, (h_, c_) = self.lstm(input_data)
        return output, h_, c_

    def sync_weights(self):
        # Copy the bidirectional LSTM's parameters into the two unidirectional
        # LSTMs; keys with the '_reverse' suffix belong to the backward direction.
        f_params = self.forward_lstm.state_dict()
        b_params = self.backward_lstm.state_dict()
        for k, v in self.lstm.state_dict().items():
            if k.endswith('_reverse'):
                b_params[k[:-len('_reverse')]] = v
            else:
                f_params[k] = v
        self.forward_lstm.load_state_dict(f_params)
        self.backward_lstm.load_state_dict(b_params)

    def unidirectional_lstms(self, input_data):
        self.sync_weights()
        # Run forward over the input and backward over the time-reversed input
        # (dim 0 is time, since batch_first=False), then re-flip and concatenate.
        input_data_reverse = torch.flip(input_data, [0])
        output_forward, (f_h_, f_c_) = self.forward_lstm(input_data)
        output_backward, (b_h_, b_c_) = self.backward_lstm(input_data_reverse)
        output_backward = torch.flip(output_backward, [0])
        output = torch.cat((output_forward, output_backward), dim=2)
        return output, f_h_, f_c_, b_h_, b_c_

    def lstm_with_cache(self, input_data, f_h, f_c, b_h, b_c):
        self.sync_weights()
        # Same as unidirectional_lstms, but seed both directions with the
        # cached (h, c) states returned by the previous call.
        input_data_reverse = torch.flip(input_data, [0])
        output_forward, (f_h_, f_c_) = self.forward_lstm(input_data, (f_h, f_c))
        output_backward, (b_h_, b_c_) = self.backward_lstm(input_data_reverse, (b_h, b_c))
        output_backward = torch.flip(output_backward, [0])
        output = torch.cat((output_forward, output_backward), dim=2)
        return output, f_h_, f_c_, b_h_, b_c_
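For reference, the weight copy above relies on PyTorch's parameter naming for a single-layer bidirectional nn.LSTM: the forward direction uses weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, and the backward direction uses the same names with a _reverse suffix. A minimal standalone probe (nothing assumed beyond standard nn.LSTM) to confirm the keys:

import torch.nn as nn

probe = nn.LSTM(5, 10, bidirectional=True)
# Expect the four forward keys plus their '_reverse' counterparts.
print(list(probe.state_dict().keys()))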
# input data: (seq_len=10, batch=2, input_size=5); batch_first=False, so dim 0 is time
torch.manual_seed(0)
input_data = torch.randn(10, 2, 5)
input_size = 5
hidden_size = 10
lstm_model = LSTMModel(input_size, hidden_size)
output_with_bidirectional_lstm, bi_h, bi_c = lstm_model.bidirectional_lstm(input_data)
# 1st slice (along dim 1)
first_out, first_f_h, first_f_c, first_b_h, first_b_c = lstm_model.unidirectional_lstms(input_data[:, 0:1, :])
# 2nd slice, seeded with the cached h/c tensors from the 1st
second_out, second_f_h, second_f_c, second_b_h, second_b_c = lstm_model.lstm_with_cache(
    input_data[:, 1:2, :], first_f_h, first_f_c, first_b_h, first_b_c)
# verify
first_is_same = torch.allclose(output_with_bidirectional_lstm[:, 0:1, :], first_out, 1e-5)
second_is_same = torch.allclose(output_with_bidirectional_lstm[:, 1:2, :], second_out, 1e-5)
print("first:", first_is_same)
print("second:", second_is_same)
What I tried:
- Flipping the cached h and c tensors before reusing them, but it did not help.
- Comparing the final h/c states between nn.LSTM(bidirectional=True) and the two nn.LSTM(bidirectional=False) modules; they are equal.
What I expect:
- To find out why the second result does not match.
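One diagnostic sketch that may narrow this down (assuming the standard nn.LSTM layout, where the bidirectional output concatenates the forward half first, so with hidden_size=10 the forward half is channels [:10] and the backward half is [10:]): compare each direction of the second result separately to see which one diverges.

# Hypothetical diagnostic: split the concatenated outputs into forward and
# backward halves and compare each against the bidirectional reference.
ref = output_with_bidirectional_lstm[:, 1:2, :]
fwd_same = torch.allclose(ref[..., :hidden_size], second_out[..., :hidden_size], 1e-5)
bwd_same = torch.allclose(ref[..., hidden_size:], second_out[..., hidden_size:], 1e-5)
print("forward half matches:", fwd_same)
print("backward half matches:", bwd_same)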