I wish to use AllenNLP Interpret for integrated visualization and saliency mapping on a custom transformer model. Can you please tell me how to do that?
1 Answer
It can be done by having AllenNLP wrappers around your custom model. The interpret modules require a Predictor object, so you can write your own, or use an existing one.
Here's an example for a classification model:
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.predictors.text_classifier import TextClassifierPredictor
from allennlp.data.dataset_readers import TextClassificationJsonReader
import torch

class ModelWrapper(Model):
    def __init__(self, vocab, your_model):
        super().__init__(vocab)
        self.your_model = your_model
        # Normalize over the class dimension (the last one).
        self.logits_to_probs = torch.nn.Softmax(dim=-1)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, tokens, label=None):
        # your_model is expected to return a dict that contains "logits".
        if label is not None:
            outputs = self.your_model(tokens, label=label)
        else:
            outputs = self.your_model(tokens)
        probs = self.logits_to_probs(outputs["logits"])
        if label is not None:
            # The interpreters backpropagate from this loss.
            loss = self.loss(outputs["logits"], label)
            outputs["loss"] = loss
        outputs["probs"] = probs
        return outputs
Your custom transformer model may not have an identifiable TextFieldEmbedder. AllenNLP normally treats this as the initial embedding layer of the model, and the saliency interpreters calculate gradients against it. You can specify the layer yourself by overriding the following methods in the Predictor.
class PredictorWrapper(TextClassifierPredictor):
    def get_interpretable_layer(self):
        # self._model is the ModelWrapper, so the custom model lives at `.your_model`.
        # This is the initial layer for huggingface's `bert-base-uncased`; change according to your custom model.
        return self._model.your_model.bert.embeddings.word_embeddings

    def get_interpretable_text_field_embedder(self):
        return self._model.your_model.bert.embeddings.word_embeddings

# `vocab` and `your_model` are assumed to have been created already.
predictor = PredictorWrapper(model=ModelWrapper(vocab, your_model),
                             dataset_reader=TextClassificationJsonReader())
Now you have an AllenNLP predictor, which can be used with the interpret module as follows:
from allennlp.interpret.saliency_interpreters import SimpleGradient
interpreter = SimpleGradient(predictor)
interpreter.saliency_interpret_from_json({"sentence": "This is a good movie."})
This should give you a saliency score for each input token, computed from the gradients with respect to that token's embedding.
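To make the returned numbers easier to read, here is a rough, library-free sketch of how a gradient-based interpreter can collapse a per-token gradient vector into a single score. As far as I recall, SimpleGradient does something along these lines (gradient times embedding, summed over the embedding dimension, then L1-normalized), but treat that as an assumption and check the source for your AllenNLP version. The function name `normalize_saliency` and the toy numbers are made up for illustration.

```python
def normalize_saliency(grads, embeddings):
    """Turn per-token gradient vectors into one non-negative score per token.

    grads and embeddings are lists with one float vector per token.
    """
    # Gradient-times-input: dot product of each token's gradient and embedding.
    raw = [
        sum(g * e for g, e in zip(grad_vec, emb_vec))
        for grad_vec, emb_vec in zip(grads, embeddings)
    ]
    # L1 norm across tokens, so the absolute scores sum to 1.
    norm = sum(abs(r) for r in raw)
    if norm == 0.0:
        return [0.0 for _ in raw]
    return [abs(r) / norm for r in raw]


# Toy values only; real gradients come from the interpreter.
tokens = ["This", "is", "good"]
grads = [[0.1, -0.2], [0.0, 0.1], [0.5, 0.5]]
embeddings = [[1.0, 1.0], [2.0, 0.0], [1.0, 1.0]]

scores = normalize_saliency(grads, embeddings)
for token, score in zip(tokens, scores):
    print(f"{token}: {score:.3f}")
```

Because the scores are normalized per input, they are only comparable between tokens of the same sentence, not across different sentences.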

akshitab