
I'm trying to reload a DistilBertForSequenceClassification model I've fine-tuned and use it to predict the appropriate labels for some sentences (text classification).

In Google Colab, after successfully training the model, I saved it and downloaded it:

trainer.train()
trainer.save_model("distilbert_classification")

The downloaded model has three files: config.json, pytorch_model.bin, training_args.bin.

I placed them in a folder named 'distilbert_classification' in my Google Drive.

Afterwards, I reloaded the model in a different Colab notebook:


reloadtrainer = DistilBertForSequenceClassification.from_pretrained('google drive directory/distilbert_classification')

Up to this point, I have succeeded without any errors.

However, how do I use this reloaded model (the 'reloadtrainer' object) to actually make predictions on sentences? What code do I need afterwards? I tried

reloadtrainer.predict("sample sentence") but it doesn't work. Would appreciate any help!

Timbus Calin
Robin311

1 Answer


Remember that you also need to tokenize the input to your model, just as in the training phase. Merely feeding a raw sentence to the model will not work (unless you use a pipeline(), but that's another discussion).
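As an aside, the pipeline route mentioned above can be sketched as follows. The model directory and the tokenizer checkpoint here are illustrative; the tokenizer must match the one used during training, and if you saved the tokenizer into the same directory you can omit the tokenizer argument:

```python
from transformers import pipeline

# Illustrative paths: the directory produced by save_pretrained(),
# plus the original tokenizer checkpoint used during training.
classifier = pipeline("text-classification",
                      model="path_to_model",
                      tokenizer="distilbert-base-uncased")

# Returns a list of dicts with "label" and "score" keys.
print(classifier("This is the text I am trying to classify."))
```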

You may use AutoModelForSequenceClassification and AutoTokenizer to make things easier.

Note that the way I am saving the model is via model.save_pretrained("path_to_model") rather than model.save().

One possible approach could be the following (say you trained with uncased distilbert):

  import torch
  from transformers import AutoModelForSequenceClassification, AutoTokenizer

  model = AutoModelForSequenceClassification.from_pretrained("path_to_model")
  # Replace with whatever tokenizer you used during training
  tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)

  input_text = "This is the text I am trying to classify."
  tokenized_text = tokenizer(input_text,
                             truncation=True,
                             is_split_into_words=False,
                             return_tensors='pt')

  # Pass both input_ids and attention_mask; no gradients needed at inference
  with torch.no_grad():
      outputs = model(**tokenized_text)
  predicted_label = outputs.logits.argmax(-1)
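Note that predicted_label above is a tensor of class indices. If you want the string label, Hugging Face model configs expose an id2label mapping; a minimal sketch with toy logits standing in for outputs.logits:

```python
import torch

# Toy logits for a batch of one example with 3 classes,
# standing in for the real outputs.logits tensor.
logits = torch.tensor([[0.1, 2.5, -1.0]])
predicted_label = logits.argmax(-1)
print(predicted_label.item())  # index of the highest-scoring class: 1

# With a real model, map the index back to a label name:
# label_name = model.config.id2label[predicted_label.item()]
```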
Timbus Calin