I'm trying to build a Pytorch network for image captioning.
Currently I have a working network of Encoder and Decoder, and I want to add nn.MultiheadAttnetion
layer to it (to be used as self attention).
Currently my decode looks like this:
class Decoder(nn.Module):
def __init__(self, hidden_size, embed_dim, vocab_size, layers = 1):
super(Decoder, self).__init__()
self.embed_dim = embed_dim
self.vocab_size = vocab_size
self.layers = layers
self.hidden_size = hidden_size
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, batch_first = True, num_layers = layers)
#self.attention = nn.MultiheadAttention(hidden_size, num_heads=1, batch_first= True)
self.fc = nn.Linear(hidden_size, self.vocab_size)
def init_hidden(self, batch_size):
h = torch.zeros(self.layers, batch_size, self.hidden_size).to(device)
c = torch.zeros(self.layers, batch_size, self.hidden_size).to(device)
return h,c
def forward(self, features, caption):
batch_size = caption.size(0)
caption_size = caption.size(1)
h,c = self.init_hidden(batch_size)
embeddings = self.embedding(caption)
lstm_input = torch.cat((features.unsqueeze(1), embeddings[:,:-1,:]), dim=1)
output, (h,c) = self.lstm(lstm_input, (h,c))
#output, _ = self.attention(output, output, output)
output = self.fc(output)
return output
def generate_caption(self, features, max_caption_size = MAX_LEN):
h,c = self.init_hidden(1)
caption = ""
embeddings = features.unsqueeze(1)
for i in range(max_caption_size):
output, (h, c) = self.lstm(embeddings, (h,c))
#output, _ = self.attention(output, output, output)
output = self.fc(output)
_, word_index = torch.max(output, dim=2) # take the word with highest probability
if word_index == vocab.get_index(END_WORD):
break
caption += vocab.get_word(word_index) + " "
embeddings = self.embedding(torch.LongTensor([word_index]).view(1,-1).to(device))
return caption
and it gives relatively good results for image captioning. I want to add the commented out lines so the model will use Attention. But- when I do that- the model breaks, although the loss becomes extremely low (decreasing from 2.7 to 0.2 during training instead of 2.7 to 1 without the attention) - the caption generation is not really working (predicts the same word over and over again).
My questions are:
- Am I using the
nn.MultiheadAttention
correctly? it is very weird to me that it should be used after the LSTM, but I saw this online, and it works from dimension sizes perspective - Any idea why my model breaks when I use Attention?
EDIT: I also tried to put the Attention before the LSTM, and it didn't work as well (network predicted the same caption for every picture)