I think I understand the use of src_key_padding_mask, thanks to Difference between src_mask and src_key_padding_mask. However, I was expecting src_key_padding_mask to make the output zero (or negative infinity) at the masked positions. Am I using it correctly, or do I need to modify the following snippet?
Please note that I know I should be using positional encoding; I have left it out on purpose.
import random
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
random.seed(42)
torch.manual_seed(42)
DIM = 5
BATCH = 2
x = [torch.randn(random.randint(1, 3), DIM) for _ in range(BATCH)]  # BATCH sequences of random length (1 to 3 steps)
mask = pad_sequence([torch.LongTensor([1] * len(elem)) for elem in x]) == 0  # (T, B), True at padded positions
padded_x = pad_sequence(x)  # (T, B, DIM), zero-padded
encoder_layer = nn.TransformerEncoderLayer(d_model=DIM, nhead=1)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=2).eval()  # eval mode so dropout does not affect the comparison
# the outputs of the following two calls are the same except where masked; I was expecting zeros there:
out1 = encoder(padded_x, src_key_padding_mask=mask.T)  # mask transposed to (B, T), the shape src_key_padding_mask expects
out2 = encoder(padded_x)
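For reference, what I was expecting the mask to achieve is roughly what I get by zeroing the padded time steps myself, as in the sketch below (the masked_fill call is just my own post-processing, assuming the default sequence-first layout used above, not anything the encoder is documented to do):

# what I expected: the padded time steps of the output explicitly zeroed out
# mask is (T, B) and out1 is (T, B, DIM), so unsqueeze lets the mask broadcast over features
expected = out1.masked_fill(mask.unsqueeze(-1), 0.0)
print(expected)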