
When using GPT-2 in PyTorch, we can simply pass the 'labels' argument to get the loss, as follows:

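```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT2LMHeadModel = GPT-2 transformer + language-modeling head
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Passing labels makes the model compute the next-token cross-entropy loss
# (the labels are shifted inside the model)
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
```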
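For reference, the loss GPT2LMHeadModel returns when labels=input_ids is the mean cross-entropy of predicting each token from the ones before it: the logits at position t are scored against the label at position t+1. The NumPy sketch below reproduces that calculation from raw logits, so it could be applied to any logits array, including one returned by an ONNX Runtime session. The function name causal_lm_loss is hypothetical; it assumes logits of shape (batch, seq_len, vocab_size), integer labels of shape (batch, seq_len), and -100 marking positions to ignore (the default ignore index of PyTorch's CrossEntropyLoss).

```python
import numpy as np


def causal_lm_loss(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Mean next-token cross-entropy, mirroring the label shift GPT2LMHeadModel applies."""
    # Position t predicts token t+1: drop the last logit and the first label
    shift_logits = logits[:, :-1, :].astype(np.float64)
    shift_labels = labels[:, 1:]
    # Numerically stable log-softmax over the vocabulary axis
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Positions labelled ignore_index do not contribute to the mean
    mask = shift_labels != ignore_index
    safe_labels = np.where(mask, shift_labels, 0)
    token_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return float(-(token_log_probs * mask).sum() / mask.sum())
```

With the PyTorch example above, causal_lm_loss(outputs.logits.detach().numpy(), inputs["input_ids"].numpy()) should agree with outputs.loss.item() up to floating-point rounding.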

However, I can't find out how to get the same loss in an ONNX Runtime inference session. I am using the code below, which only returns 'last_hidden_state':

```python
import onnxruntime as ort
from transformers import GPT2TokenizerFast

# Load the GPT-2 tokenizer that matches the exported model
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

ort_session = ort.InferenceSession("onnx/gpt2/model.onnx")

inputs = tokenizer("Using BERT in ONNX!", return_tensors="np")
# Only the "last_hidden_state" output is requested here
outputs = ort_session.run(["last_hidden_state"], dict(inputs))
```
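The output names passed to ort_session.run must exist in the graph, and the feed must contain exactly the inputs the graph declares. A small sketch (the helper names output_names and build_feed are hypothetical) that inspects the session with get_inputs() and get_outputs(), keeps only the declared inputs, and casts integer inputs to the dtype the graph expects. It is duck-typed, relying only on the name and type attributes of the NodeArg objects ONNX Runtime returns. If the graph needs inputs the tokenizer does not produce (for example position_ids or past_* tensors in some GPT-2 exports), it raises a KeyError instead of guessing them.

```python
from typing import Any, Dict, List, Mapping

import numpy as np

# ONNX type strings reported by NodeArg.type for common integer inputs
_ONNX_INT_TYPES = {"tensor(int64)": np.int64, "tensor(int32)": np.int32}


def output_names(session: Any) -> List[str]:
    """Names of every output the graph produces (e.g. to check for "logits" or "loss")."""
    return [node.name for node in session.get_outputs()]


def build_feed(session: Any, encoding: Mapping[str, Any]) -> Dict[str, np.ndarray]:
    """Keep only the inputs the graph declares, cast to the integer dtype it expects."""
    feed = {}
    for node in session.get_inputs():
        if node.name not in encoding:
            raise KeyError(f"graph input {node.name!r} was not produced by the tokenizer")
        # dtype=None leaves the array's dtype unchanged for any other declared type
        dtype = _ONNX_INT_TYPES.get(node.type)
        feed[node.name] = np.asarray(encoding[node.name], dtype=dtype)
    return feed
```

With the session from the question, print(output_names(ort_session)) shows whether the graph has a "logits" or "loss" output at all, and ort_session.run(None, build_feed(ort_session, inputs)) returns every output, since passing None as the output list requests all of them.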
Sergii Dymchenko

1 Answer


How was "onnx/gpt2/model.onnx" generated?

It looks like the PyTorch run uses transformers.GPT2LMHeadModel, while the ORT run uses transformers.GPT2Model, which is a "bare GPT2 Model transformer outputting raw hidden-states without any specific head on top" and therefore returns neither logits nor a loss.
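If the existing GPT2Model export is kept, the logits can still be recovered outside the graph: in GPT2LMHeadModel the language-modeling head is a bias-free linear layer whose weight is tied to the token-embedding matrix wte, and GPT2Model's last_hidden_state already includes the final layer norm. Under those assumptions (tied weights, the plain GPT2Model exported in eval mode), the sketch below projects the hidden states onto the vocabulary. The function name logits_from_hidden is hypothetical; the embedding matrix could be taken from the PyTorch model, e.g. GPT2LMHeadModel.from_pretrained("gpt2").transformer.wte.weight.detach().numpy().

```python
import numpy as np


def logits_from_hidden(last_hidden_state: np.ndarray, wte: np.ndarray) -> np.ndarray:
    """Apply GPT-2's tied LM head: logits = hidden @ wte^T.

    last_hidden_state: (batch, seq_len, n_embd), e.g. the ONNX "last_hidden_state" output
    wte: (vocab_size, n_embd) token-embedding matrix shared with the LM head
    returns: (batch, seq_len, vocab_size) logits
    """
    if last_hidden_state.shape[-1] != wte.shape[1]:
        raise ValueError("hidden size does not match the embedding matrix")
    # The tied head has no bias, so one matrix product is the whole projection
    return last_hidden_state @ wte.T
```

The resulting logits can then be scored with the usual shifted next-token cross-entropy, which should reproduce the PyTorch loss up to floating-point differences.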

Sergii Dymchenko
  • Hi, I used MyGPT2LMHeadModel from Gpt2Helper.py to generate 'model.onnx' and saved it with the export_onnx method:
    ```
    import torch
    from onnxruntime.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel
    from transformers import AutoConfig

    model_name_or_path = "gpt2"
    cache_dir = None  # not defined in the original comment; None uses the default cache
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    model = MyGPT2LMHeadModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)
    device = torch.device("cpu")
    model.eval().to(device)

    onnx_model_path = "onnx/gpt2/gpt2.onnx"
    Gpt2Helper.export_onnx(model, device, onnx_model_path)
    ```
    – Sachin Saxena Aug 13 '21 at 17:19
  • I was able to change the forward() method of GPT2LMHeadModel to pass 'labels = input_ids', which returned "loss" as the first output and solved my problem. – Sachin Saxena Aug 19 '21 at 17:31
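Rather than editing the library's forward() in place, the same idea can be sketched as a thin wrapper module that always passes labels=input_ids and returns the loss as the first output, which torch.onnx.export can then trace. The names GPT2WithLoss and export_with_loss, the input/output names, the dynamic axes, and opset 14 are assumptions for illustration, not what the commenter used; the graph this would produce takes input_ids and attention_mask and returns loss and logits.

```python
import torch
from transformers import GPT2LMHeadModel


class GPT2WithLoss(torch.nn.Module):
    """Hypothetical wrapper: scores the input against itself so the graph outputs a loss."""

    def __init__(self, lm: GPT2LMHeadModel):
        super().__init__()
        self.lm = lm

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        out = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,  # the model shifts the labels internally
            use_cache=False,  # keep past key/values out of the outputs
            return_dict=True,
        )
        return out.loss, out.logits


def export_with_loss(lm: GPT2LMHeadModel, path: str, opset: int = 14) -> None:
    wrapper = GPT2WithLoss(lm).eval()
    # Random token ids used only for tracing; batch and sequence length stay dynamic
    input_ids = torch.randint(0, lm.config.vocab_size, (1, 8), dtype=torch.long)
    attention_mask = torch.ones_like(input_ids)
    torch.onnx.export(
        wrapper,
        (input_ids, attention_mask),
        path,
        input_names=["input_ids", "attention_mask"],
        output_names=["loss", "logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "logits": {0: "batch", 1: "sequence"},
        },
        opset_version=opset,
    )
```

After such an export, ort_session.run(["loss"], {"input_ids": ids, "attention_mask": mask}) with int64 arrays should return the loss directly, without computing it from the logits in Python.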