I'm currently trying to predict upcoming words given an input text chunk, but I want to "mask" the last n words of the input text by setting the attention weights to 0 (or something very small).
This is what I tried to do:
I tried modifying the biases on all layers of my GPT-2 model, setting them to a very negative value for every token I want to mask. I read that the bias values are added to the dot product of the query and key vectors, so I figured they have to be negative and as large in magnitude as possible to push the resulting attention weights towards 0. A blog post on self-attention mentioned using either -inf or -1 billion (as in GPT), but if I use any value < -1 I get errors, possibly because I produce values so small that they underflow. (Although I find it odd that -1 is basically the smallest cutoff value I can still use; that's not very small at all.)
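To illustrate what I mean by "adding a bias to the dot product", here is a tiny standalone PyTorch sketch with made-up scores (not the actual GPT-2 internals), showing how -inf or -1e9 zeroes out the attention weights after the softmax:

import torch
import torch.nn.functional as F

# made-up attention scores of one query over 6 key positions
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0, 1.5, 0.5]])

# additive bias: 0 where attention should be kept, -inf / -1e9 on the last two positions
for masked_value in (float("-inf"), -1e9):
    bias = torch.tensor([[0.0, 0.0, 0.0, 0.0, masked_value, masked_value]])
    print(F.softmax(scores + bias, dim=-1))
# in both cases the masked positions end up with (numerically) 0 attention weight,
# which is the effect I'm after inside the model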
This is what I'd need advice on:
a) Does changing the biases like that make sense at all? I'm new to GPT models, so I'm always a little unsure whether what I'm doing is correct.
b) If my approach makes sense, why can't I use values < -1? Is there a way to use smaller values?
c) If not, and I use -1 instead, would that still reduce the attention weights to something around 0? (See the small numerical sketch right after this list.)
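The numerical sketch I mean for (c), again with made-up scores in plain PyTorch rather than the real model:

import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])  # made-up attention scores
for masked_value in (-1.0, -1e9):
    bias = torch.tensor([[0.0, 0.0, masked_value, masked_value]])  # mask the last two positions
    print(masked_value, F.softmax(scores + bias, dim=-1))
# a bias of -1 only divides the masked positions' unnormalised weight by e (about 2.7),
# so they still get clearly non-zero attention; with -1e9 they end up at ~0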
This is my code:
# import modules
!pip install transformers
import numpy as np
import math
import tensorflow as tf
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# download pre-trained German GPT-2 model & tokenizer from the Hugging Face model hub
tokenizer = AutoTokenizer.from_pretrained("dbmdz/german-gpt2")
# initialise the model, using the end-of-sequence (EOS) token as the padding token
model = AutoModelForCausalLM.from_pretrained("dbmdz/german-gpt2", pad_token_id = tokenizer.eos_token_id)
# set the input text and the number of words to mask
input_text = ["Orlando", "liebte", "von", "Natur", "aus", "einsame", "Orte,", "weite", "Ausblicke", "und", "das", "Gefühl,", "für", "immer", "und", "ewig"] # the original sentence continues: "allein zu sein." ("to be alone")
n = 5 # I want to mask the last n words
print("masking the following", n, "words now:", )
# 1 word can consist of many tokens, so get last n words
# and tokenize them so we know how many tokens we have to mask:
masked_words = input_text[-n:]
n_tokens = len(tokenizer.tokenize(" ".join(masked_words)))
print(" ".join(masked_words))
# encode the full input text (including the words we want to mask) and get the attention mask
encoded_input = tokenizer.encode_plus(" ".join(input_text),
                                      add_special_tokens = False,   # don't add special tokens
                                      return_attention_mask = True, # return the attention mask for the current text input
                                      return_tensors = 'pt')        # return output as PyTorch tensor object
# get attention mask from encoded input
attention_mask = encoded_input['attention_mask']
# check how many entries there are in the attention mask.
# mask_length should be the number of tokens in the full sentence
mask_length = attention_mask.size()[1]
# Overwrite the attention mask entries for the last n_tokens tokens
# with a large negative value,
# but leave the very last token set to 1 (for the space after the last masked word)
# attention_mask[:, -(n_tokens + 1): -1] = -1 # this works
attention_mask[:, -(n_tokens + 1): -1] = -100000000 # this doesn't work
#print(attention_mask)
# Now we want to modify the attention weights on all layers by setting
# the biases to our attention mask values.
# Loop over the modules in model.transformer
# (all attention layers are modules of the transformer),
# find the attention modules and set our custom attention mask as their biases
for module in model.transformer.modules():
    # if the current module is a MultiHeadAttention object (aka an attention module)...
    if isinstance(module, torch.nn.MultiheadAttention):
        # set the attention mask we defined earlier as the biases
        module.register_buffer("bias", attention_mask.unsqueeze(0))
# Get prediction from full model:
# use ids from input text (input_ids) & the modified attention mask to generate the output
output = model.generate(encoded_input['input_ids'],
                        attention_mask = attention_mask,
                        max_new_tokens = 10)
# decode the generated token IDs into a text string
predicted_text = tokenizer.decode(output[0],
                                  skip_special_tokens = True)
# print the predicted text
print("\n Predicted text:", predicted_text)
Thanks in advance for your help/ideas/comments!