This question is similar to How to deal with stack expects each tensor to be equal size error while fine tuning GPT-2 model?, but it differs in that the other question trains the model without a Trainer
object, whereas this question uses one.
Additionally, this question relies on the deprecated TextDataset
object, so a separate answer seems warranted.
TL;DR
Firstly, the TextDataset
object is deprecated. Secondly, if you want a fuss-free approach to fine-tuning a GPT model, see https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling
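For reference, a CLI run from that examples directory looks roughly like this (flags taken from the examples README; check the README for the current options):
python run_clm.py \
    --model_name_or_path gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm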
In Long
While GPT2LMHeadModel
has its own modeling code inside the transformers library,
the canonical example for fine-tuning a causal language model is https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
That example uses the wikitext
dataset for its CLI training run, so to see what the Python under the hood looks like, we can start with:
from datasets import load_dataset

# 'wikitext' has several configs; 'wikitext-2-raw-v1' is the one
# used in the run_clm.py example.
wikitext = load_dataset('wikitext', 'wikitext-2-raw-v1')
wikitext
[out]:
DatasetDict({
test: Dataset({
features: ['text'],
num_rows: 4358
})
train: Dataset({
features: ['text'],
num_rows: 36718
})
validation: Dataset({
features: ['text'],
num_rows: 3760
})
})
If we look into an individual dataset split, e.g.
# The example at index 10 of the training split.
wikitext['train'][10]
[out]:
{'text': ' The game \'s battle system , the <unk> system , is carried over directly from <unk> Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action <unk> . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can either help or impede a character , and " Battle Potentials " , which are grown throughout the game and always grant <unk> to a character . To learn Battle Potentials , each character has a unique " Masters Table " , a grid @-@ based skill table that can be used to acquire and link different skills . Characters also have Special <unk> that grant them temporary <unk> on the battlefield : Kurt can activate " Direct Command " and move around the battlefield without <unk> his Action Point gauge , the character <unk> can shift into her " Valkyria Form " and become <unk> , while Imca can target multiple enemy units with her heavy weapon . \n'}
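Each split exposes a single 'text' column of raw strings, and that is the structure we want to mimic with our own data:
wikitext['train'].column_names
[out]:
['text']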
Given how the wikitext
dataset is structured, using your own custom data with the new generic Dataset
object (instead of the old TextDataset
) is as simple as:
from datasets import Dataset
texts = ["hello world", 'fizz buzz and foo bar', 'hallo welt is hello world in German', 'foo bar bar bleh sheep']
my_dataset = Dataset.from_dict({'text': texts})
my_dataset[2]
[out]:
{'text': 'hallo welt is hello world in German'}
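(Side note: if you also need a validation split from your own texts, the Dataset object can produce one directly; the 0.25 fraction and the seed below are arbitrary choices.)
splits = my_dataset.train_test_split(test_size=0.25, seed=42)
train_ds, valid_ds = splits['train'], splits['test']
train_ds.num_rows, valid_ds.num_rows
[out]:
(3, 1)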
Now our Dataset
object has the same structure as wikitext
. To fine-tune the model, we can simply reuse the rest of the code from https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py
Here's a minimal working example:
from itertools import chain
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import default_data_collator
from transformers import Trainer
from datasets import Dataset
import evaluate
# Load the GPT2 tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer))
MAX_LENGTH = tokenizer.model_max_length
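# (For the plain 'gpt2' checkpoint this is 1024. If your tokenizer reports a
# huge sentinel value for model_max_length, cap block_size manually instead.)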
# From https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py#L489
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples, block_size=MAX_LENGTH):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
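# Quick illustration of group_texts on a hypothetical toy batch: two "documents"
# worth of token ids are concatenated and re-split into blocks of 4:
# group_texts({'input_ids': [[1, 2, 3], [4, 5, 6, 7, 8]],
#              'attention_mask': [[1, 1, 1], [1, 1, 1, 1, 1]]}, block_size=4)
# -> {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]],
#     'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1]],
#     'labels': [[1, 2, 3, 4], [5, 6, 7, 8]]}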
# Training data.
texts = ["hello world", 'fizz buzz and foo bar', 'hallo welt is hello world in German', 'foo bar bar bleh sheep']
raw_datasets = Dataset.from_dict({'text': texts})
tokenized_datasets = raw_datasets.map(
    lambda x: tokenizer(x['text']),
    batched=True,
    # Drop the raw strings so group_texts only concatenates token ids,
    # mirroring the remove_columns usage in run_clm.py.
    remove_columns=['text'],
)
train_dataset = tokenized_datasets.map(
group_texts,
batched=True,
)
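# Sanity check (toy data): the four short sentences are concatenated and fall
# into a single block, so expect a very small num_rows here.
print(train_dataset)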
# Validation data.
eval_texts = ["hello world blah blah bad romance, ra ra fizz buzz"]
eval_dataset = Dataset.from_dict({'text': eval_texts}).map(
    lambda x: tokenizer(x['text']),
    batched=True,
    remove_columns=['text'],
).map(
    group_texts,
    batched=True,
)
metric = evaluate.load("accuracy")

# From run_clm.py: reduce the logits to predicted token ids before they are
# gathered, so compute_metrics doesn't receive the full vocab-sized logits.
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # preds have the same shape as the labels, after the argmax(-1) has been
    # calculated by preprocess_logits_for_metrics, but we need to shift the labels
    labels = labels[:, 1:].reshape(-1)
    preds = preds[:, :-1].reshape(-1)
    return metric.compute(predictions=preds, references=labels)
# Initialize our Trainer. No TrainingArguments are passed here, so the Trainer
# falls back to its defaults; pass `args=TrainingArguments(...)` to control the
# output dir, batch size, epochs, evaluation strategy, etc.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
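After training, you can evaluate on the held-out set and save the fine-tuned model; the output path here is just a placeholder:
print(trainer.evaluate())
trainer.save_model('./my-finetuned-gpt2')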
But do take a close look at the code in https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py to understand what each part is doing, and if you want no-frills fine-tuning, just use the CLI.