
I'm trying to train an RNN for machine translation, using an LSTM. However, the BLEU score drops to zero after the first batch and stays there for the whole training run, while the loss decreases drastically. What may be the problem?

**The code:**

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # used by the model below


class SimpleRNNTranslator(nn.Module):
    def __init__(self, inp_voc, out_voc, emb_size=64, hid_size=128):
        """
        My version of simple RNN model, I use LSTM instead of GRU as in the baseline
        """
        super().__init__()
        
        self.inp_voc = inp_voc
        self.out_voc = out_voc
        
        self.emb_inp = nn.Embedding(len(inp_voc), emb_size)
        self.emb_out = nn.Embedding(len(out_voc), emb_size)
        
        self.encoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.decoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        
        self.decoder_start = nn.Linear(hid_size, hid_size)
        self.logits = nn.Linear(hid_size, len(out_voc))
        
    def forward(self, inp, out):
        """
        Apply model in training mode
        """
        encoded_seq = self.encode(inp)
        decoded_seq, _ = self.decode(encoded_seq, out)
        return self.logits(decoded_seq)
    
    def encode(self, seq_in):
        """
        Take input symbolic sequence, compute initial hidden state for decoder
        :param seq_in: matrix of input tokens [batch_size, seq_in_len]
        :return: initial hidden state for the decoder
        """
        embeddings = self.emb_inp(seq_in)
        output, (_, __) = self.encoder(embeddings)
    
        # the last state isn't actually the last one because of padding; the next two lines find the true last state
        seq_lengths = (seq_in != self.inp_voc.eos_ix).sum(dim=-1)
        
        last_states = output[range(seq_lengths.shape[0]), seq_lengths]
        
        return self.decoder_start(last_states)
    
    def decode(self, hidden_state, seq_out, previous_state=None):
        """
        Take output symbolic sequence, compute logits for every token in sequence
        :param hidden_state: matrix of initial_hidden_state [batch_size, hid_size]
        :param previous_state: matrix of previous state [batch_size, hid_size]
        :param seq_out: matrix of output tokens [batch_size, seq_out_len]
        :return: logits for every token in sequence [batch_size, seq_len, out_voc]
        """
        if not torch.is_tensor(previous_state):
            previous_state = torch.randn(*hidden_state.shape).to(device)
            
        embeddings = self.emb_out(seq_out)
        outputs, (_, cn) = self.decoder(embeddings, (hidden_state[None, :, :], previous_state[None, :, :]))
        
        return outputs, cn
    
    def inference(self, inp_tokens, max_len):
        """
        Take input tokens and return ids of the output words
        :param inp_tokens: matrix of input tokens [batch_size, seq_in_len]
        :param max_len: maximum number of decoding steps
        """
        initial_state = self.encode(inp_tokens)
        states = [initial_state]
        outputs = [torch.full([initial_state.shape[0]], self.out_voc.bos_ix, dtype=torch.long, device=device)]
        
        cn = None
        
        for i in range(max_len):
            hidden_state, cn = self.decode(states[-1], outputs[-1][:, None], previous_state=cn)
            hidden_state, cn = hidden_state.squeeze(), cn.squeeze()
            outputs.append(self.logits(hidden_state).argmax(dim=-1))
            states.append(hidden_state)

        
        return torch.stack(outputs, dim=-1), torch.cat(states)
            
    
    def translate_lines(self, lines, max_len=100):
        """
        Take lines and return translation
        :param lines: list of lines in Russian
        """
        inp_tokens = self.inp_voc.to_matrix(lines).to(device)
        out_ids, states = self.inference(inp_tokens, max_len=max_len)
        return self.out_voc.to_lines(out_ids.cpu().numpy()), states
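
For completeness, this is roughly how the model is used (the inp_voc/out_voc vocabulary objects come from the rest of the notebook):

# rough usage sketch; inp_voc and out_voc are defined earlier in the notebook
model = SimpleRNNTranslator(inp_voc, out_voc, emb_size=64, hid_size=128).to(device)
lines = ["пример предложения"]  # input lines in Russian
translations, states = model.translate_lines(lines, max_len=100)
print(translations[0])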


**How I compute BLEU:**

from nltk.translate.bleu_score import corpus_bleu
def compute_bleu(model, inp_lines, out_lines, bpe_sep='@@ ', **flags):
    """
    Estimates corpora-level BLEU score of model's translations given inp and reference out
    Note: if you're serious about reporting your results, use https://pypi.org/project/sacrebleu
    """
    with torch.no_grad():
        translations, _ = model.translate_lines(inp_lines, **flags)
        translations = [line.replace(bpe_sep, '') for line in translations]
        actual = [line.replace(bpe_sep, '') for line in out_lines]
        return corpus_bleu(
            [[ref.split()] for ref in actual],
            [trans.split() for trans in translations],
            smoothing_function=lambda precisions, **kw: [p + 1.0 / p.denominator for p in precisions]
            ) * 100
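
It is evaluated on the development split roughly like this (dev_inp_lines and dev_out_lines stand in for the actual lists of development-set lines):

# rough usage sketch; dev_inp_lines / dev_out_lines are placeholders for the dev-set line lists
dev_bleu = compute_bleu(model, dev_inp_lines, dev_out_lines, max_len=100)
print("dev BLEU:", dev_bleu)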

[Plots: BLEU score evaluated on the development dataset, and training loss]

I thought this problem might be related to how the LSTM works. At first, I didn't pass the cell state between the elements of the sequence, only the hidden state. I fixed this, but it didn't resolve the issue.

Haski

1 Answer


You probably forgot to shift the target sequence when computing the loss.

At training time, the decoder input needs to be shifted so that the (n-1)-th token predicts the n-th word. For a sequence w1 w2 w3 w4 with beginning-of-sentence token [BOS] and end-of-sentence token [EOS], it looks like this:

BOS w1  w2  w3  w4
↓   ↓   ↓   ↓   ↓
▯ → ▯ → ▯ → ▯ → ▯  
↓   ↓   ↓   ↓   ↓
w1  w2  w3  w4  EOS 

Generally speaking: you feed the decoder the target sequence without the last token and compute the loss with respect to the target sequence without the first token.
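
In PyTorch, a minimal sketch of such a training step could look like this (padding masking is omitted for brevity; model, inp and out are the model and token matrices from your question):

import torch.nn.functional as F

# sketch of a shifted training step; padding handling is omitted for brevity
logits = model(inp, out[:, :-1])           # decoder input: target without the last token
targets = out[:, 1:]                       # loss targets: target without the first (BOS) token
loss = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),  # [batch * (seq_len - 1), len(out_voc)]
    targets.reshape(-1),
)
loss.backward()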

When you don't do this, the decoder looks like this:

w1  w2  w3  w4
↓   ↓   ↓   ↓
▯ → ▯ → ▯ → ▯
↓   ↓   ↓   ↓
w1  w2  w3  w4

The model quickly learns to copy its decoder input tokens (each position simply predicts the token it was fed), so the loss rapidly decreases, but the model does not learn to translate.

Jindřich