
I am retraining the GPT-2 language model, following this blog post:

https://towardsdatascience.com/train-gpt-2-in-your-own-language-fc6ad4d60171

In that post, the author trains a GPT-2 model, and I am trying to reproduce the same setup. However, my dataset is large (250 MB), so I want to train in intervals. In other words, I want to checkpoint the training run and resume it later. Any help, or a piece of code I can use to checkpoint and continue training, would be greatly appreciated. Thank you.

Vivek

1 Answer

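from transformers import Trainer, TrainingArguments

# model, tokenizer, train_set, dev_set and model_checkpoint (a directory path)
# are assumed to be defined already.
training_args = TrainingArguments(
    output_dir=model_checkpoint,  # checkpoints and the final model are written here
    # other hyper-params
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=dev_set,
    tokenizer=tokenizer
)

trainer.train()
# Save the final model to output_dir (here: model_checkpoint)
trainer.save_model()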

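from transformers import AutoModelForCausalLM

def prepare_model(tokenizer, model_name_path):
    # Load weights from a model name or from a saved model directory
    model = AutoModelForCausalLM.from_pretrained(model_name_path)
    # Match the embedding matrix to the tokenizer's vocabulary size
    model.resize_token_embeddings(len(tokenizer))
    return model

# Assuming tokenizer is defined, you can simply pass the saved model directory path
# to reload the model and continue training from it.
model = prepare_model(tokenizer, model_checkpoint)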
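Reloading the saved model as above restores the weights only. If you also want the optimizer and scheduler state and the step count back, `Trainer.train` accepts a `resume_from_checkpoint` argument. The following is a minimal sketch, not a definitive recipe: `latest_checkpoint` and `train_in_intervals` are hypothetical helper names introduced here, and the code assumes the Trainer's default layout of `checkpoint-<step>` folders inside `output_dir`.

```python
import re
from pathlib import Path
from typing import Optional


def latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the newest `checkpoint-<step>` folder, or None.

    Hypothetical helper; assumes the Trainer's default folder naming.
    """
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in root.iterdir():
        match = re.fullmatch(r"checkpoint-(\d+)", child.name)
        if child.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), str(child)
    return best_path


def train_in_intervals(trainer, output_dir: str):
    """Resume from the newest checkpoint if one exists, else start fresh."""
    checkpoint = latest_checkpoint(output_dir)
    # None starts a new run; a path asks the Trainer to restore its saved state.
    return trainer.train(resume_from_checkpoint=checkpoint)
```

To use it, replace `trainer.train()` with `train_in_intervals(trainer, model_checkpoint)` and ask for periodic saves in `TrainingArguments`, for example `save_strategy="steps"`, `save_steps=500`, `save_total_limit=2` (the numbers are placeholders). Each time you rerun the script it should pick up from the last saved step. The library also ships `transformers.trainer_utils.get_last_checkpoint`, which you may prefer over the hand-written helper.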