Would it make sense to set these weights to zeros?
As you said, these tokens are ignored during the self-attention calculation, so setting their embeddings to zero would not change anything.
Let's have a look at the relevant BERT code as an example:
from transformers import BertTokenizer, BertModel, BertConfig
import torch

sample = "This is"
model_id = 'bert-base-uncased'

t = BertTokenizer.from_pretrained(model_id)

# Shrink the model to a single attention head and a single layer
# to keep the tensors in this example small
c = BertConfig.from_pretrained(model_id)
c.num_attention_heads = 1
c.num_hidden_layers = 1
m = BertModel.from_pretrained(model_id, config=c)

# Pad the two-word sample to a fixed length of 5 tokens
encoded_input = t(sample, padding='max_length', max_length=5, return_tensors='pt')
print(encoded_input)
The model input consists of 5 tokens: the [CLS] token, two text tokens, the [SEP] token, and one [PAD] token:
{'input_ids': tensor([[ 101, 2023, 2003, 102, 0]]),
'token_type_ids': tensor([[0, 0, 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 0]])}
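If you want to double-check which token each id corresponds to, the tokenizer can map the ids back to tokens (the expected output is shown as a comment):
print(t.convert_ids_to_tokens(encoded_input['input_ids'][0].tolist()))
# ['[CLS]', 'this', 'is', '[SEP]', '[PAD]']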
The attention_mask tells the model that the first 4 tokens should attend to each other and that the fifth token should be ignored. BERT does not use the attention mask as it is; it converts it into an extended_attention_mask:
extended_attention_mask = m.get_extended_attention_mask(encoded_input['attention_mask'], encoded_input['input_ids'].shape)
print(extended_attention_mask)
The extended_attention_mask contains a large negative value (effectively negative infinity, e.g. torch.finfo(torch.float32).min or -10000, depending on your transformers version) for every token that should not be taken into account during the self-attention calculation, and zero otherwise (code):
# Please note: the exact value used as "negative infinity" depends on your transformers version and dtype, so you might see a different number here
tensor([[[[ -0., -0., -0., -0., -10000.]]]])
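Conceptually, the conversion does roughly the following (a minimal sketch of what get_extended_attention_mask computes, assuming the fixed -10000 constant used by older versions; newer versions use the dtype minimum instead):
# Broadcast the [batch, seq_len] mask to [batch, 1, 1, seq_len] so it can be
# added to attention scores of shape [batch, num_heads, seq_len, seq_len]
mask = encoded_input['attention_mask'][:, None, None, :].to(torch.float32)
extended = (1.0 - mask) * -10000.0  # 0.0 for real tokens, -10000.0 for padding
print(extended)
# tensor([[[[    -0.,     -0.,     -0.,     -0., -10000.]]]])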
It is applied before the softmax is calculated from the QK^T product (code): the large negative value is added to the attention scores of the padding positions. Due to the huge difference between these scores and the rest, the following softmax assigns zero probability to the padding positions:
attention_scores = torch.tensor([[[[ 9.5116e+00,  2.4427e-01, -1.1232e+00,  1.2221e+00, -1.0003e+04],
                                   [ 6.4593e+00,  5.6316e+00,  6.7172e+00,  7.7484e+00, -9.9928e+03],
                                   [ 4.6683e+00,  8.1287e+00,  6.1758e+00,  7.5101e+00, -9.9916e+03],
                                   [ 1.0366e+01,  8.0461e+00,  7.5019e+00,  9.2650e+00, -9.9944e+03],
                                   [ 1.2470e+01,  4.6752e+00,  5.9156e+00,  9.9091e+00, -9.9891e+03]]]])
attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)
print(attention_probs)
Output:
tensor([[[[9.9963e-01, 9.4426e-05, 2.4055e-05, 2.5105e-04, 0.0000e+00],
          [1.5721e-01, 6.8711e-02, 2.0347e-01, 5.7061e-01, 0.0000e+00],
          [1.8351e-02, 5.8412e-01, 8.2864e-02, 3.1466e-01, 0.0000e+00],
          [6.7211e-01, 6.6057e-02, 3.8333e-02, 2.2350e-01, 0.0000e+00],
          [9.2672e-01, 3.8169e-04, 1.3195e-03, 7.1576e-02, 0.0000e+00]]]])
Even if you set the padding token's embedding to zero, its masked attention score is still so far below the others that the softmax output, and therefore the output for the non-padding positions, does not change.
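If you want to verify this, here is a quick sanity check (a minimal sketch that reuses the model m and tokenizer t from above): zero out the [PAD] embedding in a copy of the model and compare the hidden states of the non-padding positions.
import copy

m.eval()
m2 = copy.deepcopy(m)
with torch.no_grad():
    # Zero out the embedding of the [PAD] token in the copied model
    m2.embeddings.word_embeddings.weight[t.pad_token_id] = 0.0
    out1 = m(**encoded_input).last_hidden_state
    out2 = m2(**encoded_input).last_hidden_state

# The non-padding positions are unaffected; only the padding position itself
# can change, because its own embedding still flows through the residual path.
print(torch.allclose(out1[:, :4], out2[:, :4]))  # expected: True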