
I've been trying to find a library or an example for getting token importance when a BERT model predicts a masked span, e.g.:

from transformers import BertTokenizerFast, BertForMaskedLM
import torch

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

text = 'Brad Pitt is an [MASK] actor.'

tokenized_text = tokenizer.tokenize(text)
# Add the special tokens BERT was trained with; without them the
# predictions for the masked position are noticeably worse
tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]']
masked_index = tokenized_text.index("[MASK]")
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
tokens_tensor = torch.tensor([indexed_tokens])

# Predict all tokens
with torch.no_grad():
    outputs = model(tokens_tensor)
    predictions = outputs[0]

probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)

You could then pick the highest-probability token, or the top 5, for example:
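top5 = torch.topk(probs, 5)
for prob, idx in zip(top5.values, top5.indices):
    print(tokenizer.convert_ids_to_tokens(idx.item()), prob.item())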

How would I go about computing, say, vanilla gradients (or any other saliency method) to see which tokens were important when predicting the masked token?
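What I have in mind is something like the sketch below: take the gradient of the predicted token's logit with respect to the input embeddings and use its norm as a per-token saliency score. The inputs_embeds argument and model.bert.embeddings.word_embeddings are from the Hugging Face BERT implementation; the rest is my guess at how vanilla gradients would look here.

# Embed the ids ourselves so we can take gradients w.r.t. the embeddings
input_embeds = model.bert.embeddings.word_embeddings(tokens_tensor)
input_embeds = input_embeds.detach().requires_grad_(True)

outputs = model(inputs_embeds=input_embeds)
logits = outputs[0]

# Logit of the top predicted token at the masked position
predicted_id = logits[0, masked_index].argmax()
score = logits[0, masked_index, predicted_id]

# Backpropagate that single logit to the input embeddings
score.backward()

# One saliency value per input token: L2 norm of its embedding gradient
saliency = input_embeds.grad[0].norm(dim=-1)
for token, s in zip(tokenized_text, saliency):
    print(token, s.item())

Is this roughly the standard recipe, or is there a library that already packages it?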

I read Ecco's documentation, but it doesn't support attribution for BERT yet. AllenNLP has a demo for the MLM task, but the code is tied to that demo, and I couldn't find a worked MLM example for SHAP or Captum either.
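For Captum specifically, the closest I could piece together from its docs is a LayerIntegratedGradients setup like the one below. I'm not sure this is the right wiring for MLM; the forward_func and the baseline choice are my own guesses:

from captum.attr import LayerIntegratedGradients

def forward_func(input_ids):
    # Scores over the vocabulary at the masked position
    return model(input_ids)[0][:, masked_index]

lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

# Guessed baseline: [PAD] everywhere except the special tokens and the mask
ref_ids = tokens_tensor.clone()
for i, tok in enumerate(tokenized_text):
    if tok not in ('[CLS]', '[SEP]', '[MASK]'):
        ref_ids[0, i] = tokenizer.pad_token_id

predicted_id = probs.argmax().item()
attributions = lig.attribute(inputs=tokens_tensor,
                             baselines=ref_ids,
                             target=predicted_id)

# Sum over the embedding dimension for one score per token
scores = attributions.sum(dim=-1).squeeze(0)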

Any help pointing me in the right direction would be appreciated.

Paschalis
