I've been trying to find a library or an example for getting token importance when a BERT model predicts a masked token, e.g.:
from transformers import BertTokenizerFast, BertForMaskedLM
import torch

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

text = 'Brad Pitt is an [MASK] actor.'
# tokenize() does not add special tokens, so add the [CLS] and [SEP] that BERT expects
tokenized_text = [tokenizer.cls_token] + tokenizer.tokenize(text) + [tokenizer.sep_token]
masked_index = tokenized_text.index(tokenizer.mask_token)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

# Predict all tokens
with torch.no_grad():
    outputs = model(tokens_tensor)
# outputs[0] holds the logits, shape (batch, sequence_length, vocab_size)
predictions = outputs[0]
# Probability distribution over the vocabulary at the masked position
probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
You could then pick the highest predicted probability, or the top 5.
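As a small illustration of that selection step, here is a sketch using `torch.topk`. The `vocab` list and the `logits` values are made up so the snippet runs on its own; with the real model you would pass `probs` from above and map the indices back with `tokenizer.convert_ids_to_tokens`.

```python
import torch

# Hypothetical toy vocabulary and logits, standing in for the BERT vocabulary
# and for predictions[0, masked_index] from the snippet above.
vocab = ["american", "english", "excellent", "actor", "the"]
logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0])

# Turn the logits into a probability distribution over the toy vocabulary
probs = torch.softmax(logits, dim=-1)

# Take the k most probable entries (k=5 would match the "top 5" case)
top = torch.topk(probs, k=3)
top_tokens = [vocab[i] for i in top.indices.tolist()]

for token, p in zip(top_tokens, top.values.tolist()):
    print(f"{token}: {p:.3f}")
```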
How would I go about calculating, say, vanilla gradients or any other type of saliency method, to see which tokens were important when predicting the masked token?
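To make the question concrete, this is roughly what I understand by vanilla gradients here: take the gradient of the predicted token's logit at the masked position with respect to the input embeddings, then reduce it to one score per token. It is only a sketch under assumptions, not a verified method. It uses a tiny, randomly initialised `BertForMaskedLM` built from a `BertConfig` so that it runs without downloading weights. The token ids, `masked_index`, and the choice of the L2 norm as the reduction are all made up for illustration.

```python
import torch
from transformers import BertConfig, BertForMaskedLM

torch.manual_seed(0)

# Assumption: a tiny random model instead of bert-base-uncased, purely so the
# sketch is self-contained; the scores it produces are not meaningful.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertForMaskedLM(config)
model.eval()

# Hypothetical token ids; position 4 plays the role of the [MASK] token.
input_ids = torch.tensor([[1, 10, 11, 12, 3, 13, 2]])
masked_index = 4

# Look up the word embeddings and make them a leaf tensor that tracks gradients
embeds = model.get_input_embeddings()(input_ids).detach().requires_grad_(True)

# Forward pass from embeddings instead of ids, so gradients reach `embeds`
logits = model(inputs_embeds=embeds).logits

# Explain the top predicted token at the masked position
target_id = logits[0, masked_index].argmax()
logits[0, masked_index, target_id].backward()

# One score per input token: L2 norm of the gradient over the embedding dimension,
# normalised to sum to 1 (one possible reduction among several)
saliency = embeds.grad[0].norm(dim=-1)
saliency = saliency / saliency.sum()

for position, score in enumerate(saliency.tolist()):
    print(position, round(score, 4))
```

With the real model, `input_ids` would be `tokens_tensor` from above and the positions would map back to `tokenized_text`.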
I read Ecco's documentation, but it doesn't support attribution for BERT yet. AllenNLP has a demo for the MLM task, but it only covers that demo, and I couldn't find anything relevant using SHAP or Captum.
Any help pointing me in the right direction would be appreciated.