I'm trying to train an RNN for machine translation, using an LSTM. However, the BLEU score drops to zero after the first batch and stays there for the whole training run, while the loss keeps decreasing drastically. What might the problem be?
**The code:**
```python
import torch
import torch.nn as nn

# `device` is defined earlier in the notebook, e.g.:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SimpleRNNTranslator(nn.Module):
    def __init__(self, inp_voc, out_voc, emb_size=64, hid_size=128):
        """
        My version of the simple RNN model; I use an LSTM instead of the GRU from the baseline
        """
        super().__init__()
        self.inp_voc = inp_voc
        self.out_voc = out_voc
        self.emb_inp = nn.Embedding(len(inp_voc), emb_size)
        self.emb_out = nn.Embedding(len(out_voc), emb_size)
        self.encoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.decoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.decoder_start = nn.Linear(hid_size, hid_size)
        self.logits = nn.Linear(hid_size, len(out_voc))

    def forward(self, inp, out):
        """
        Apply the model in training mode
        """
        encoded_seq = self.encode(inp)
        decoded_seq, _ = self.decode(encoded_seq, out)
        return self.logits(decoded_seq)

    def encode(self, seq_in):
        """
        Take an input symbolic sequence and compute the initial hidden state for the decoder
        :param seq_in: matrix of input tokens [batch_size, seq_in_len]
        :return: initial hidden state for the decoder
        """
        embeddings = self.emb_inp(seq_in)
        output, (_, __) = self.encoder(embeddings)
        # the last state isn't actually the last one because of padding;
        # the next two lines find the true last state
        seq_lengths = (seq_in != self.inp_voc.eos_ix).sum(dim=-1)
        last_states = output[range(seq_lengths.shape[0]), seq_lengths]
        return self.decoder_start(last_states)

    def decode(self, hidden_state, seq_out, previous_state=None):
        """
        Take an output symbolic sequence and compute decoder outputs for every token in it
        :param hidden_state: matrix of initial hidden states [batch_size, hid_size]
        :param previous_state: matrix of previous cell states [batch_size, hid_size]
        :param seq_out: matrix of output tokens [batch_size, seq_out_len]
        :return: decoder outputs for every token [batch_size, seq_out_len, hid_size] and the final cell state
        """
        if not torch.is_tensor(previous_state):
            previous_state = torch.randn(*hidden_state.shape).to(device)
        embeddings = self.emb_out(seq_out)
        outputs, (_, cn) = self.decoder(embeddings, (hidden_state[None, :, :], previous_state[None, :, :]))
        return outputs, cn

    def inference(self, inp_tokens, max_len):
        """
        Take input tokens and return ids of output words (greedy decoding)
        :param inp_tokens: matrix of input tokens [batch_size, seq_in_len]
        :param max_len: maximum number of decoding steps
        """
        initial_state = self.encode(inp_tokens)
        states = [initial_state]
        outputs = [torch.full([initial_state.shape[0]], self.out_voc.bos_ix, dtype=torch.int, device=device)]
        cn = None
        for i in range(max_len):
            hidden_state, cn = self.decode(states[-1], outputs[-1][:, None], previous_state=cn)
            hidden_state, cn = hidden_state.squeeze(), cn.squeeze()
            outputs.append(self.logits(hidden_state).argmax(dim=-1))
            states.append(hidden_state)
        return torch.stack(outputs, dim=-1), torch.cat(states)

    def translate_lines(self, lines, max_len=100):
        """
        Take lines and return translations
        :param lines: list of lines in Russian
        """
        inp_tokens = self.inp_voc.to_matrix(lines).to(device)
        out_ids, states = self.inference(inp_tokens, max_len=max_len)
        return self.out_voc.to_lines(out_ids.cpu().numpy()), states
```
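For completeness, the loss is roughly a teacher-forced cross-entropy on the logits from `forward`. The sketch below is simplified with illustrative names (`train_step`, `inp_batch`, `out_batch`, `opt`), not the exact notebook code, and it omits details such as padding masking:

```python
import torch.nn.functional as F

def train_step(model, inp_batch, out_batch, opt):
    """One simplified training step: teacher-forced cross-entropy over the output tokens."""
    logits = model(inp_batch, out_batch)               # [batch, out_len, len(out_voc)]
    # predict token t+1 from tokens up to t
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.shape[-1]),  # predictions for positions 1..T
        out_batch[:, 1:].reshape(-1),                  # reference tokens at positions 1..T
    )
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```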
**How I compute BLEU:**
```python
from nltk.translate.bleu_score import corpus_bleu


def compute_bleu(model, inp_lines, out_lines, bpe_sep='@@ ', **flags):
    """
    Estimates the corpus-level BLEU score of the model's translations given inp and reference out
    Note: if you're serious about reporting your results, use https://pypi.org/project/sacrebleu
    """
    with torch.no_grad():
        translations, _ = model.translate_lines(inp_lines, **flags)
        translations = [line.replace(bpe_sep, '') for line in translations]
        actual = [line.replace(bpe_sep, '') for line in out_lines]
        return corpus_bleu(
            [[ref.split()] for ref in actual],
            [trans.split() for trans in translations],
            smoothing_function=lambda precisions, **kw: [p + 1.0 / p.denominator for p in precisions]
        ) * 100
```
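It is called on the development set roughly like this (`dev_src` and `dev_ref` are placeholder names for my lists of BPE-tokenized source and reference lines):

```python
# dev_src / dev_ref: dev-set source and reference lines (placeholder names)
dev_bleu = compute_bleu(model, dev_src, dev_ref)
print(f"dev BLEU: {dev_bleu:.2f}")
```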
*[Training plots: BLEU score evaluated on the development dataset, and the training loss]*
I suspected the problem might be related to how the LSTM works: at first I didn't pass the cell state between the elements of the sequence, only the hidden state. I fixed this, but it didn't resolve the issue.
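To make explicit what I mean by "passing the cell state": the step-by-step decoding now threads both the hidden state `h` and the cell state `c` through the loop. A minimal standalone illustration with `nn.LSTM` (not my actual code, just the pattern):

```python
import torch
import torch.nn as nn

# Step-by-step LSTM decoding that carries both h and c between steps.
lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
x = torch.randn(4, 1, 64)       # one token embedding per step, batch of 4
h = torch.zeros(1, 4, 128)      # [num_layers, batch, hid]
c = torch.zeros(1, 4, 128)
for _ in range(10):
    out, (h, c) = lstm(x, (h, c))   # both h and c are fed back on the next step
    x = torch.randn(4, 1, 64)       # next input embedding (placeholder)
```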