I have a sequence-to-sequence POS tagging model that uses a Transformer decoder to generate the target tokens. I use PyTorch's Transformer decoder as follows:
In the initialization:

```python
self.decoder_layer = nn.TransformerDecoderLayer(d_model=ENV_HIDDEN_SIZE, nhead=2, batch_first=True,
                                                dim_feedforward=300, activation="relu")
self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=2)
```
And in the forward function:

```python
if infer == False:  # for training
    embedded = embedded * math.sqrt(ENV_HIDDEN_SIZE)
    embedded = self.pos_encoder(embedded)
    zol = self.transformer_decoder(tgt=embedded, memory=newtensor,
                                   memory_mask=self.transformer_mask,
                                   memory_key_padding_mask=x_mask,
                                   tgt_mask=self.transformer_mask)
    scores = self.slot_trans(self.dropout3(zol))
else:  # for inference (greedy decoding, one tag per step)
    bos = Variable(torch.LongTensor([[tag2index['<BOS>']] * batch_size])).cuda().transpose(1, 0)
    bos = self.embedding(bos)
    tokens = bos
    for i in range(length):
        temp_embedded = tokens * math.sqrt(ENV_HIDDEN_SIZE)
        temp_embedded = self.pos_encoder(temp_embedded)
        zol = self.transformer_decoder(tgt=temp_embedded,
                                       memory=newtensor,
                                       tgt_mask=self.transformer_mask[:i+1, :i+1],
                                       memory_key_padding_mask=x_mask,
                                       memory_mask=self.transformer_mask[:i+1, :])
        scores = self.slot_trans(self.dropout3(zol))
        softmaxed = self.softmax(scores)
        _, input = torch.max(softmaxed, 2)
        # re-embed all predictions so far and prepend <BOS> for the next step
        newtok = self.embedding(input)
        tokens = torch.cat((bos, newtok), dim=1)
```
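To make the effect of `memory_mask` in the decoder call concrete, here is a minimal pure-Python sketch of a single cross-attention head. It is a simplification, assuming no learned Q/K/V projections and only one head, so it only approximates what happens inside each `nn.TransformerDecoderLayer`. As with PyTorch's float attention masks, the mask is added to the attention scores, so `-inf` entries receive zero weight. The helpers `causal_mask`, `cross_attention` and the toy vectors are hypothetical, for illustration only. The sketch shows that with the causal mask, the output for tag position i does not change when a word after position i changes, while without the mask it does.

```python
import math

def causal_mask(n):
    # 0.0 on and below the diagonal, -inf above it (same layout as generate_square_subsequent_mask)
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]

def softmax(scores):
    m = max(scores)  # finite as long as at least one position is unmasked
    exps = [math.exp(s - m) for s in scores]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(queries, memory, mask=None):
    """Single-head scaled dot-product attention: queries play the role of decoder (tag)
    states, keys/values the encoder (word) states. mask[i][j] is added to the score."""
    d = len(queries[0])
    out = []
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in memory]
        if mask is not None:
            scores = [s + mask[i][j] for j, s in enumerate(scores)]
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, memory)) for c in range(d)])
    return out

queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy decoder states for 3 tags
words_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy encoder states for 3 words
words_b = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # same sentence, only the last word differs

mask = causal_mask(3)
masked_a = cross_attention(queries, words_a, mask)
masked_b = cross_attention(queries, words_b, mask)
free_a = cross_attention(queries, words_a)
free_b = cross_attention(queries, words_b)

print(masked_a[0] == masked_b[0], masked_a[1] == masked_b[1])  # True True: tags 0 and 1 never see word 2
print(free_a[0] == free_b[0])  # False: without the mask, tag 0 also attends to word 2
```

Because a POS tagger emits exactly one tag per word, the target and source lengths match, which is why the same square `self.transformer_mask` fits both `tgt_mask` and `memory_mask` in the code above.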
The mask `self.transformer_mask`, which I pass as both `tgt_mask` and `memory_mask`, is generated by the function `generate_square_subsequent_mask`:

```python
import torch

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
```
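When this mask is passed as `memory_mask`, row i corresponds to target (tag) position i and column j to source (word) position j. The pure-Python sketch below rebuilds the same 0/`-inf` layout without torch and lists which words each tag position may attend to; `build_mask` and `visible_words` are hypothetical helper names used only for illustration.

```python
import math

def build_mask(sz):
    # Same layout as generate_square_subsequent_mask: 0.0 on/below the diagonal, -inf above it
    return [[0.0 if j <= i else -math.inf for j in range(sz)] for i in range(sz)]

def visible_words(mask):
    # For each tag position (row), the word positions (columns) that are not blocked
    return [[j for j, v in enumerate(row) if v != -math.inf] for row in mask]

for i, cols in enumerate(visible_words(build_mask(4))):
    print(f"tag {i} can attend to words {cols}")
# tag 0 can attend to words [0]
# tag 1 can attend to words [0, 1]
# tag 2 can attend to words [0, 1, 2]
# tag 3 can attend to words [0, 1, 2, 3]
```

So with `memory_mask` set, tag i can use word i and the words to its left but never the words to its right; without it, every tag position can attend to the whole sentence.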
I am observing something odd. If I do not pass the output of `generate_square_subsequent_mask` as `memory_mask` (which, according to this post, I should not do), accuracy drops severely. Furthermore, test-set accuracy fluctuates randomly between 50% and 90% from epoch to epoch, while training-set accuracy does not. If I do pass the `memory_mask`, everything is fine and test accuracy rises steadily to 95%. The final accuracy is also lower without the `memory_mask`. Things I tried:
- Without `memory_mask`: tuning the learning rate.
- Without `memory_mask`: increasing `nhead` and `num_layers`.
- Replacing the Transformer decoder with a simple linear layer.
As a final note, using a simple linear layer instead of the Transformer decoder gives better accuracy. Any ideas why this is happening?