I am trying to fine-tune a Sentence-BERT (STS) model for textual entailment classification with the labels "entailment", "neutral", and "contradiction".
Here is the training script from the Sentence-Transformers NLI examples: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/nli/training_nli_v2.py
Part of the NLI training data: https://huggingface.co/datasets/snli
The script trains an NLI model (evaluated against the Semantic Textual Similarity benchmark) and saves it with the code below:
# Path where the fine-tuned model will be saved
model_save_path = 'output/training_nli_v2_'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
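For illustration, here is what that save path evaluates to, assuming a hypothetical model_name such as 'nreimers/MiniLM-L6-H384-uncased' (any slash in the model name is replaced with a dash, and a timestamp is appended):

from datetime import datetime

model_name = 'nreimers/MiniLM-L6-H384-uncased'  # hypothetical example value
model_save_path = 'output/training_nli_v2_' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
print(model_save_path)
# e.g. output/training_nli_v2_nreimers-MiniLM-L6-H384-uncased-2024-01-01_12-00-00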
My question is: how can I use the output model to predict 0 (neutral), 1 (entailment), or 2 (contradiction)?
Would it be something like this? (ref. https://huggingface.co/facebook/bart-large-mnli)
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

nli_model = AutoModelForSequenceClassification.from_pretrained(f'training_nli_v2_{model_name}').to(device)
tokenizer = AutoTokenizer.from_pretrained(f'training_nli_v2_{model_name}')

premise = "I am thirsty"
hypothesis = "I want water"

# Tokenize the premise/hypothesis pair; if the input is too long,
# truncate only the first sequence (the premise)
x = tokenizer.encode(premise, hypothesis, return_tensors='pt',
                     truncation='only_first')
logits = nli_model(x.to(device))[0]

# We throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true
entail_contradiction_logits = logits[:, [0, 2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:, 1]
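For what it's worth, if the fine-tuned model does expose a three-way classification head, the label could be read off by taking the argmax over all three logits instead of discarding the "neutral" dimension. A minimal sketch with hypothetical logits, assuming the label order 0 = neutral, 1 = entailment, 2 = contradiction from the question (note that facebook/bart-large-mnli itself uses a different order: 0 = contradiction, 1 = neutral, 2 = entailment):

import torch

# Hypothetical logits for one premise/hypothesis pair,
# in the assumed order: 0 = neutral, 1 = entailment, 2 = contradiction
logits = torch.tensor([[0.3, 2.1, -1.0]])

# Softmax over all three classes, then pick the most probable label
probs = logits.softmax(dim=1)
pred = int(probs.argmax(dim=1))

label_names = ["neutral", "entailment", "contradiction"]
print(pred, label_names[pred])  # → 1 entailment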
Any help would be greatly appreciated.