
I am having a hard time understanding transformers. Things are getting clearer bit by bit, but one thing that has me scratching my head is the difference between src_mask and src_key_padding_mask, which are passed as arguments to the forward function of both the encoder layer and the decoder layer.

https://pytorch.org/docs/master/_modules/torch/nn/modules/transformer.html#Transformer

Leo
  • from the official pytorch forum: `The src_mask is just a square matrix which is used to filter the attention weights. ... src_key_padding_mask is more like a padding marker, which masks a specific tokens in the src sequence (a.k.a. the entire column/row of the attention matrix is set to '-inf').` https://discuss.pytorch.org/t/nn-transformer-explaination/53175/5 – Charlie Parker Jul 14 '21 at 21:06
  • another post https://discuss.pytorch.org/t/transformer-difference-between-src-mask-and-src-key-padding-mask/84024 from the official pytorch forum. – Charlie Parker Jul 14 '21 at 21:07
  • yet another one: https://www.reddit.com/r/pytorch/comments/okfh2k/transformer_difference_between_src_mask_and_src/ – Charlie Parker Jul 14 '21 at 22:37
  • related question: https://stackoverflow.com/questions/62399243/transformerencoder-with-a-padding-mask – Charlie Parker Jul 14 '21 at 22:44
  • perhaps reading the docs for MHA is the best...? https://pytorch.org/docs/master/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention – Charlie Parker Jul 14 '21 at 22:55

3 Answers


Difference between src_mask and src_key_padding_mask

The general thing to notice is the difference between the tensors named _mask and those named _key_padding_mask. Inside the transformer, when attention is computed we get an intermediate tensor of pairwise comparisons of size [Tx, Tx] (for the input to the encoder), [Ty, Ty] (for the shifted output, one of the inputs to the decoder) and [Ty, Tx] (for the memory mask, the attention between the encoder output/memory and the decoder input/shifted output). The first two are square.

So these are the uses of each of the masks in the transformer (note the notation from the PyTorch docs: Tx=S is the source sequence length (e.g. the max over the input batch), Ty=T is the target sequence length (e.g. the max target length), B=N is the batch size, and D=E is the feature dimension):

  1. src_mask [Tx, Tx] = [S, S] – the additive mask for the src sequence (optional). This is applied when computing atten_src + src_mask. I'm not sure of an example input (see tgt_mask for one), but the typical use is to add -inf, so one could mask the src attention that way if desired. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions are unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while positions with False are unchanged. If a FloatTensor is provided, it is added to the attention weight.

  2. tgt_mask [Ty, Ty] = [T, T] – the additive mask for the tgt sequence (optional). This is applied when computing atten_tgt + tgt_mask. An example use is the triangular (causal) mask that keeps the decoder from cheating by looking at future tokens. The tgt is right-shifted so that its first token is the start-of-sequence (SOS/BOS) embedding; accordingly, the first row of the mask is zero in its first entry and -inf in the remaining ones. See the concrete example in the appendix. The dtype rules are the same as for src_mask.

  3. memory_mask [Ty, Tx] = [T, S] – the additive mask for the encoder output (optional). This is applied when computing atten_memory + memory_mask. I'm not sure of an example use but, as before, adding -inf sets some of the attention weights to zero. The dtype rules are the same as for src_mask.

  4. src_key_padding_mask [B, Tx] = [N, S] – the mask for src keys per batch (optional). Since your src sequences usually have different lengths, it's common to mask out the padding vectors you appended at the end. For this you mark, for each example in your batch, which positions are padding. See the concrete example in the appendix. If a ByteTensor is provided, the non-zero positions are ignored as keys while the zero positions are unchanged. If a BoolTensor is provided, positions with True are ignored while positions with False are unchanged. If a FloatTensor is provided, it is added to the attention weight.

  5. tgt_key_padding_mask [B, Ty] = [N, T] – the mask for tgt keys per batch (optional). Same as the previous one, but for the tgt sequence. See the concrete example in the appendix.

  6. memory_key_padding_mask [B, Tx] = [N, S] – the mask for memory keys per batch (optional). Same as the previous one, but for the encoder output. See the concrete example in the appendix.
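As a rough illustration of the shapes listed above, the following sketch passes all six masks to `nn.Transformer.forward`. The sizes, the random inputs, and the choice of which positions count as padding are made-up toy values, not something taken from the tutorial.

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions): source length, target length, batch size, feature size.
S, T, N, E = 5, 4, 2, 16

model = nn.Transformer(d_model=E, nhead=2, num_encoder_layers=1,
                       num_decoder_layers=1, dim_feedforward=32, dropout=0.0)

# The default layout is (sequence, batch, feature).
src = torch.randn(S, N, E)
tgt = torch.randn(T, N, E)

# [S, S], [T, T] and [T, S] attention masks; True means "not allowed to attend".
src_mask = torch.zeros(S, S, dtype=torch.bool)                          # blocks nothing
tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)   # blocks the future
memory_mask = torch.zeros(T, S, dtype=torch.bool)                       # blocks nothing

# [N, S] and [N, T] padding masks; True means "this position is padding".
src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
src_key_padding_mask[1, -2:] = True   # pretend the second example ends with 2 pads
tgt_key_padding_mask = torch.zeros(N, T, dtype=torch.bool)
tgt_key_padding_mask[1, -1:] = True   # pretend the second target ends with 1 pad

with torch.no_grad():
    out = model(src, tgt,
                src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask,
                src_key_padding_mask=src_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=src_key_padding_mask)

print(out.shape)  # (T, N, E)
```

The memory padding mask is simply the source padding mask here, since the memory is the encoder's output for the same source positions.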

Appendix

Examples from pytorch tutorial (https://pytorch.org/tutorials/beginner/translation_transformer.html):

1 src_mask example

    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)

returns a tensor of booleans of size [Tx, Tx]:

tensor([[False, False, False,  ..., False, False, False],
         ...,
        [False, False, False,  ..., False, False, False]])
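An all-False src_mask blocks nothing, so it should behave like passing no mask at all. The sketch below checks this on a single `nn.TransformerEncoderLayer` with made-up sizes and dropout disabled; it is only meant as a sanity check, not as part of the tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumptions): sequence length 5, batch size 2, feature size 8.
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16, dropout=0.0)
x = torch.randn(5, 2, 8)

all_false = torch.zeros(5, 5, dtype=torch.bool)  # True would mean "blocked"

with torch.no_grad():
    out_none = layer(x)                        # no src_mask
    out_false = layer(x, src_mask=all_false)   # src_mask that blocks nothing

same = torch.allclose(out_none, out_false, atol=1e-6)
print(same)
```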

2 tgt_mask example

    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1)
    mask = mask.transpose(0, 1).float()
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, float(0.0))

This generates the triangular (causal) mask for the right-shifted output, which is the input to the decoder:

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
         -inf, -inf, -inf],
         ...,
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])

Usually the right-shifted output has the BOS/SOS token at the beginning. The tutorial gets the right shift simply by prepending that BOS/SOS token and then trimming the last element with tgt_input = tgt[:-1, :].
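To make the tgt_mask snippet runnable on its own, here is a small sketch that wraps it in a function (the name `tutorial_mask` is mine, and the `device` argument is dropped for brevity), compares it with a more compact construction, and shows the effect on the softmax using made-up all-zero scores.

```python
import torch

def tutorial_mask(sz):
    # Same steps as the snippet above: 0.0 on and below the diagonal, -inf above it.
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1).float()
    mask = mask.masked_fill(mask == 0, float("-inf"))
    return mask.masked_fill(mask == 1, 0.0)

def compact_mask(sz):
    # Keep -inf strictly above the diagonal; torch.triu zeroes everything else.
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

sz = 4
mask = tutorial_mask(sz)
same = torch.equal(mask, compact_mask(sz))

# Toy attention scores (all equal) to show what adding the mask does.
scores = torch.zeros(sz, sz)
weights = torch.softmax(scores + mask, dim=-1)
print(weights)
```

Row i of `weights` puts equal weight on positions 0..i and exactly zero weight on later positions, which is what stops the decoder from looking ahead.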

3 _padding

The padding masks just mask out the padding at the end of each sequence. The src padding mask is usually the same as the memory padding mask. The tgt has its own sequences and thus its own padding. Example:

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    memory_padding_mask = src_padding_mask

Output:

tensor([[False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ...,  True,  True,  True]])

Note that False means there is no padding token at that position (so that value is used in the transformer forward pass) and True means there is a padding token (so it is masked out and does not affect the forward pass).
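For a concrete view of the padding mask, the sketch below builds one from a tiny batch of token ids. The ids and the value of `PAD_IDX` are made up for illustration; the batch is laid out as (sequence, batch) like in the tutorial, which is why the transpose is needed.

```python
import torch

PAD_IDX = 1  # assumed padding id for this toy example

# Shape (S=4, N=2): each column is one example.
# Example 0 is [5, 6, 9, PAD]; example 1 is [7, 8, PAD, PAD].
src = torch.tensor([[5, 7],
                    [6, 8],
                    [9, 1],
                    [1, 1]])

# Transpose to (N, S), the layout expected for src_key_padding_mask.
src_padding_mask = (src == PAD_IDX).transpose(0, 1)
print(src_padding_mask)
```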


The answers are sort of spread around, but I found only these 3 references useful (the docs for the separate layers weren't very useful, honestly):

Charlie Parker
  • For case 3 which mask should we provide if they are the same? What is the difference between the memory and src padding masks? – Evan Zamir Jul 21 '21 at 02:02
  • @EvanZamir I don't want to give general statements because I do want to encourage research creativity in how models are used (in this case, masks). That being said, the memory is usually the embedded src (from the transformer encoder) and thus usually shares the same mask as the src (both `src_mask` and `src_key_padding_mask`). For your case, if your input is the same as your output (e.g. an autoencoder) then you can use the same mask (i.e. create one vector and assign it to both variables `tgt_padding_mask = src_padding_mask`). If you tell me what task you're solving, I can help more. – Charlie Parker Jul 21 '21 at 12:13
  • I'm essentially building a model like word2vec but using just the Transformer Encoder module. So each sample is let's say up to 512 tokens but some samples could be much less and then have padding tokens obviously. Like BERT I replace 15% of tokens with a mask token and then try to predict the missing tokens. So I'm wondering now what mask I could use with the pad tokens. It seems like `src_key_padding_mask`? Do I just fill it with 1's and 0's? – Evan Zamir Jul 21 '21 at 16:48
  • @EvanZamir the `_key_padding_mask` - as used in the concrete print statement of my answer - shows that if there is a `True` then you are telling the pytorch transformer layer that there is a pad token there so it ignores it. I don't believe it actually matters what the content of the pad token is as long as you feed in this mask correctly. Anything with a `True` will be ignored and anything with a `False` will not. I added a few more comments to my answer (I suggest you read them). In short, **focus on making the mask right** and your batch of sequences will be processed correctly. – Charlie Parker Jul 21 '21 at 19:43
  • Hi @CharlieParker, thanks for the detailed comment. I have just one question: if `src_mask` is going to contain `False` values only, then what's the use of it? As you wrote, `while False values will be unchanged` – maq Oct 20 '21 at 16:40
  • @maq I believe it's because you are feeding entire tensors to an API that expects entire tensors. This is because the pytorch API receives tensors to speed up with GPUs, so when GPU computations are done, it knows in the GPU code what to change and what not to change for you. It's so pytorch is able to function and be compatible with cuda (is my guess). – Charlie Parker Oct 20 '21 at 18:06
  • @CharlieParker I have the same question for `src_mask`. It feels like `src_mask` can be computed from `src_key_padding_mask` by doing N times of "outer products" on the padding mask. Otherwise what other purpose does `src_mask` have other than marking the paddings in the square attention matrix? – CyberPlayerOne May 06 '23 at 12:13

I must say the PyTorch implementation is a bit confusing, as it contains too many mask parameters. But I can shed light on the two mask parameters that you are referring to. Both src_mask and src_key_padding_mask are used in the MultiheadAttention mechanism. According to the documentation of MultiheadAttention:

key_padding_mask – if provided, specified padding elements in the key will be ignored by the attention.

attn_mask – 2D or 3D mask that prevents attention to certain positions.

As you know from the paper, Attention is all you need, MultiheadAttention is used in both Encoder and Decoder. However, in Decoder, there are two types of MultiheadAttention. One is called Masked MultiheadAttention and another one is the regular MultiheadAttention. To accommodate both these techniques, PyTorch uses the above mentioned two parameters in their MultiheadAttention implementation.

So, long story short-

  • attn_mask and key_padding_mask are used in the Encoder's MultiheadAttention and the Decoder's Masked MultiheadAttention.
  • memory_mask is used in Decoder's MultiheadAttention mechanism as pointed out here.

Looking into the implementation of MultiheadAttention might help you.

As you can see from here and here, first src_mask is used to block specific positions from attending and then key_padding_mask is used to block attending to pad tokens.
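To see the two masks acting together, here is a minimal sketch that calls `nn.MultiheadAttention` directly with both an `attn_mask` and a `key_padding_mask`. The sizes, the random input, and the choice of the last position as padding are assumptions made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumptions): sequence length, batch size, feature size.
L, N, E = 4, 1, 8
mha = nn.MultiheadAttention(embed_dim=E, num_heads=2, dropout=0.0)

x = torch.randn(L, N, E)  # (sequence, batch, feature)

# attn_mask [L, L]: True blocks a (query, key) pair; here it blocks future keys.
attn_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
# key_padding_mask [N, L]: True marks a padded key; here the last position is padding.
key_padding_mask = torch.tensor([[False, False, False, True]])

with torch.no_grad():
    out, weights = mha(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

    # Overwrite the padded position with different content and run again.
    x2 = x.clone()
    x2[3] = torch.randn(N, E)
    out2, _ = mha(x2, x2, x2, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

print(weights[0])                                   # head-averaged weights, shape [L, L]
print(torch.allclose(out[:3], out2[:3], atol=1e-6))
```

The column of `weights[0]` belonging to the padded key is zero, as are the entries above the diagonal. The outputs at the non-padded positions do not change when the padded vector changes, which matches the remark that the content of a pad token does not matter once the mask is right.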

Note. Answer updated based on @michael-jungo's comment.

Wasi Ahmad
  • The two points under *long story short* are not correct. Firstly, an `attn_mask` and a `key_padding_mask` are used in the self-attention (enc-enc and dec-dec) as well as the encoder-decoder attention (enc-dec). Secondly, PyTorch [doesn't use the `src_mask` in the decoder, but rather the `memory_mask`](https://github.com/pytorch/pytorch/blob/ec5d579929b2c56418aacaec0874b92937d095a4/torch/nn/modules/transformer.py#L124-L127) (they are often the same, but separate in the API). `src_mask` and `src_key_padding_mask` belong to the encoder's self-attention. The last sentence summarises it quite well – Michael Jungo Jun 03 '20 at 19:40
  • @wasi can you address the comment by Michael and correct your answer if it's incorrect? – Charlie Parker Jul 14 '21 at 21:04

To give a small example, consider I want to build a sequential recommender i.e., given the items the users have purchased till time 't' predict the next item at 't+1'

u1 - [i1, i2, i7]
u2 - [i2, i5]
u3 - [i6, i7, i1, i2]

For this task, I could use a transformer where I would make the sequences equal length by padding them with 0's on the left.

u1 - [0,  i1, i2, i7]
u2 - [0,  0,  i2, i5]
u3 - [i6, i7, i1, i2]

I will use key_padding_mask to tell PyTorch that the 0's should be ignored. Now, consider user u3 where given [i6] I want to predict [i7] and later given [i6, i7] I want to predict [i1], i.e., I want causal attention, such that the attention doesn't peek at future elements. For this, I will use attn_mask. Hence for user u3 attn_mask will be like

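[[False, True , True , True ],
 [False, False, True , True ],
 [False, False, False, True ],
 [False, False, False, False]]

Here True marks a position that is not allowed to attend, which is how PyTorch interprets a boolean attn_mask.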
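The recommender example can be sketched as tensors like this. The integer item ids stand in for i1, i2, ... and 0 is the padding id, as in the example; the `allowed` tensor is just my own way of visualising how the two masks combine, not something PyTorch returns.

```python
import torch

PAD = 0  # padding id, as in the example above

# Shape (N=3 users, L=4 positions), left-padded with 0.
items = torch.tensor([[0, 1, 2, 7],    # u1
                      [0, 0, 2, 5],    # u2
                      [6, 7, 1, 2]])   # u3

# key_padding_mask [N, L]: True where the key is padding and should be ignored.
key_padding_mask = items == PAD

# attn_mask [L, L]: True where attention is blocked (keys in the future).
L = items.shape[1]
attn_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

# For illustration only: which (query, key) pairs remain allowed for each user.
allowed = ~attn_mask.unsqueeze(0) & ~key_padding_mask.unsqueeze(1)  # [N, L, L]

print(allowed[2].int())  # u3: lower-triangular, no padding
print(allowed[1].int())  # u2: the first two key columns are always blocked
```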
Sanjay