
I am using two nn.LSTM(bidirectional=False) modules in PyTorch to reproduce the behaviour of a bidirectional LSTM (nn.LSTM(bidirectional=True)). The input is a streaming sequence, and the result turns out to be wrong. Below are my implementation and a test case; is anything wrong with my implementation?
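For context: a bidirectional nn.LSTM keeps a second set of parameters for the reverse direction, with the suffix _reverse appended to each key, and the weight copying in my code relies on this naming. A quick way to inspect the layout (illustrative snippet only):

import torch.nn as nn

bi = nn.LSTM(input_size=5, hidden_size=10, bidirectional=True)
print(list(bi.state_dict().keys()))
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
#  'weight_ih_l0_reverse', 'weight_hh_l0_reverse',
#  'bias_ih_l0_reverse', 'bias_hh_l0_reverse']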

import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
        self.forward_lstm = nn.LSTM(input_size, hidden_size, bidirectional=False)
        self.backward_lstm = nn.LSTM(input_size, hidden_size, bidirectional=False)

    def bidirectional_lstm(self, input_data):
        # reference path: the built-in bidirectional LSTM
        output, (h_, c_) = self.lstm(input_data)
        return output, h_, c_

    def _copy_params_from_bidirectional(self):
        # Copy the bidirectional LSTM's forward-direction weights into
        # forward_lstm and its *_reverse weights into backward_lstm.
        f_params = self.forward_lstm.state_dict()
        b_params = self.backward_lstm.state_dict()
        for k, v in self.lstm.state_dict().items():
            if k.endswith('_reverse'):
                b_params[k[:-len('_reverse')]] = v
            else:
                f_params[k] = v
        self.forward_lstm.load_state_dict(f_params)
        self.backward_lstm.load_state_dict(b_params)

    def unidirectional_lstms(self, input_data):
        self._copy_params_from_bidirectional()

        # The backward direction reads the sequence last-to-first, so the
        # input is flipped along the time dimension (dim 0) and the output
        # is flipped back before concatenation.
        input_data_reverse = torch.flip(input_data, [0])

        output_forward, (f_h_, f_c_) = self.forward_lstm(input_data)
        output_backward, (b_h_, b_c_) = self.backward_lstm(input_data_reverse)
        output_backward = torch.flip(output_backward, [0])

        output = torch.cat((output_forward, output_backward), dim=2)

        return output, f_h_, f_c_, b_h_, b_c_

    def lstm_with_cache(self, input_data, f_h, f_c, b_h, b_c):
        self._copy_params_from_bidirectional()

        input_data_reverse = torch.flip(input_data, [0])

        # Continue both directions from the h/c states returned by the
        # previous chunk.
        output_forward, (f_h_, f_c_) = self.forward_lstm(input_data, (f_h, f_c))
        output_backward, (b_h_, b_c_) = self.backward_lstm(input_data_reverse, (b_h, b_c))
        output_backward = torch.flip(output_backward, [0])

        output = torch.cat((output_forward, output_backward), dim=2)

        return output, f_h_, f_c_, b_h_, b_c_

# input data: shape (seq_len=10, batch=2, input_size=5);
# nn.LSTM defaults to batch_first=False, so dim 0 is the time dimension
torch.manual_seed(0)
input_data = torch.randn(10, 2, 5)

input_size = 5
hidden_size = 10
lstm_model = LSTMModel(input_size, hidden_size)


# reference output from the built-in bidirectional LSTM
output_with_bidirectional_lstm, bi_h, bi_c = lstm_model.bidirectional_lstm(input_data)

# 1st chunk: a slice along dim 1
first_out, first_f_h, first_f_c, first_b_h, first_b_c = lstm_model.unidirectional_lstms(input_data[:, 0:1, :])

# 2nd chunk, seeded with the final h/c states from the 1st chunk
second_out, second_f_h, second_f_c, second_b_h, second_b_c = lstm_model.lstm_with_cache(input_data[:, 1:2, :], first_f_h, first_f_c, first_b_h, first_b_c)

# verify
fisrt_is_same = torch.allclose(output_with_bidirectional_lstm[:, 0:1, :], first_out, 1e-5)

second_is_same = torch.allclose(output_with_bidirectional_lstm[:, 1:2, :], second_out, 1e-5)

print("first:", fisrt_is_same)
print("second:", second_is_same)

What I have tried:

  1. Flipping the h and c tensors before reuse; it made no difference.
  2. Comparing the h/c states of nn.LSTM(bidirectional=True) with those of the two nn.LSTM(bidirectional=False) modules (see the snippet after this list); they are equal.
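
The comparison in point 2 looks roughly like this (for a one-layer bidirectional LSTM, bi_h has shape (2, batch, hidden_size); index 0 is the forward direction and index 1 the backward direction):

# illustrative check of the final states for the first sample/chunk
print(torch.allclose(bi_h[0:1, 0:1, :], first_f_h, rtol=1e-5))  # forward direction
print(torch.allclose(bi_h[1:2, 0:1, :], first_b_h, rtol=1e-5))  # backward direction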

What I expect:

  1. To find out why the outputs differ.
