I am trying to learn the PyTorch Transformer by using it on the DeepMind mathematics dataset. The sequences fed into the model are tokenized at the character level (not word level). The model's forward function runs the encoder once and the decoder multiple times, until every output in the batch has reached the <eos> token (this part is still a TODO). I am struggling with the Transformer masks and the decoder forward pass, which throws this error:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
RuntimeError: shape '[-1, 24, 64]' is invalid for input of size 819200.
The source is N = 32, S = 50, E = 512 and the target is N = 32, S = 3, E = 512. It is possible that my mask implementation is wrong, or that the problem is the differing source and target lengths; I am not really sure.
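To make the shapes concrete, here is roughly how I call the model (a minimal sketch with dummy tensors rather than my real data pipeline, and I am assuming the default (seq_len, batch) layout that nn.Transformer expects):

import torch

# token indices in (seq_len, batch) layout; vocab_length = 30, as in the model below
src = torch.randint(0, 30, (50, 32))  # source: S = 50, N = 32
trg = torch.randint(0, 30, (3, 32))   # target: T = 3,  N = 32

# after the embedding layer (d_model = 512) these become (50, 32, 512) and (3, 32, 512)
model = MyTransformerModel()          # defaults below: vocab_length=30, attention_heads=8, batch_size=32
output = model(src, trg)              # the RuntimeError above is raised inside this call

Here is my model code: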
import math

import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    # positionally encodes the src and target sequences
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class MyTransformerModel(nn.Module):
    # should implement init and forward function
    # define separate functions for masks
    # define the forward function, which should implement:
    #   embedding layer
    #   positional encoding
    #   encoder layer
    #   decoder layer
    #   final classification layer
    # encoder -> forward once
    # decoder -> forward multiple times (for one encoder forward)
    # decoder output => concatenate to input, e.g. decoder_input = torch.cat([decoder_input, decoder_output])
    # early stopping => all sequences in the batch reach the <eos> token
    def __init__(self, vocab_length=30, sequence_length=512, num_encoder_layers=3,
                 num_decoder_layers=2, num_hidden_dimension=256, feed_forward_dimensions=1024,
                 attention_heads=8, dropout=0.1, pad_idx=3, device="cpu", batch_size=32):
        super(MyTransformerModel, self).__init__()
        # note: sequence_length (512) is used as the embedding / model dimension (d_model) here
        self.src_embedding = nn.Embedding(vocab_length, sequence_length)
        self.pos_encoder = PositionalEncoding(sequence_length, dropout)
        self.src_mask = None  # attention mask
        self.memory_mask = None  # attention mask
        self.pad_idx = pad_idx
        self.device = device
        self.batch_size = batch_size
        self.transformer = nn.Transformer(
            sequence_length,
            attention_heads,
            num_encoder_layers,
            num_decoder_layers,
            feed_forward_dimensions,
            dropout,
        )
    def src_att_mask(self, src_len):
        mask = (torch.triu(torch.ones(src_len, src_len)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def no_peak_att_mask(self, batch_size, src_len, time_step):
        mask = np.zeros((batch_size, src_len), dtype=bool)
        mask[:, time_step:] = 1  # np.NINF
        mask = torch.from_numpy(mask)
        return mask

    def make_src_key_padding_mask(self, src):
        # mask "<pad>" positions
        src_mask = src.transpose(0, 1) == self.pad_idx
        return src_mask.to(self.device)

    def make_trg_key_padding_mask(self, trg):
        tgt_mask = trg.transpose(0, 1) == self.pad_idx
        return tgt_mask.to(self.device)
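    # for reference, the shapes I believe these helpers produce (assuming src is (S, N) and trg is (T, N)):
    #   make_src_key_padding_mask(src) -> (N, S)  boolean
    #   make_trg_key_padding_mask(trg) -> (N, T)  boolean
    #   no_peak_att_mask(batch_size, src_len, t) -> (batch_size, src_len)  boolean
    #   src_att_mask(src_len) -> (src_len, src_len)  float, -inf above the diagonal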
    def forward(self, src, trg):
        src_seq_length, N = src.shape
        trg_seq_length, N = trg.shape
        embed_src = self.src_embedding(src)
        position_embed_src = self.pos_encoder(embed_src)
        embed_trg = self.src_embedding(trg)
        position_embed_trg = self.pos_encoder(embed_trg)
        src_padding_mask = self.make_src_key_padding_mask(src)
        trg_padding_mask = self.make_trg_key_padding_mask(trg)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_length).to(self.device)
        time_step = 1
        att_mask = self.no_peak_att_mask(self.batch_size, src_seq_length, time_step).to(self.device)
        encoder_output = self.transformer.encoder.forward(position_embed_src, src_key_padding_mask=src_padding_mask)
        # TODO: implement the loop over the transformer decoder forward fn, implement early stopping
        # where to feed decoder_output?
        decoder_output = self.transformer.decoder.forward(position_embed_trg, encoder_output, trg_mask, att_mask, trg_padding_mask, src_padding_mask)
        return decoder_output
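For reference, this is roughly the autoregressive loop I intend to wrap around the decoder once the mask problem is solved (only a sketch of the idea, not working code; sos_idx, eos_idx, max_len and the final classification layer self.fc_out are placeholders that do not exist in the model yet):

# intended greedy decoding loop, to live inside forward() after the encoder pass
decoder_input = torch.full((1, N), sos_idx, dtype=torch.long, device=self.device)
finished = torch.zeros(N, dtype=torch.bool, device=self.device)
for _ in range(max_len):
    embed_trg = self.pos_encoder(self.src_embedding(decoder_input))
    trg_mask = self.transformer.generate_square_subsequent_mask(decoder_input.size(0)).to(self.device)
    decoder_output = self.transformer.decoder(embed_trg, encoder_output, tgt_mask=trg_mask,
                                              memory_key_padding_mask=src_padding_mask)
    logits = self.fc_out(decoder_output[-1])   # (N, vocab_length)
    next_token = logits.argmax(dim=-1)         # (N,)
    # append the prediction and stop once every sequence in the batch has produced <eos>
    decoder_input = torch.cat([decoder_input, next_token.unsqueeze(0)], dim=0)
    finished |= next_token == eos_idx
    if finished.all():
        break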
Can anyone pinpoint where I have made a mistake?