I am trying to create a pet LLM using GPT-2 following instructions here: https://thomascherickal.medium.com/how-to-create-your-own-llm-model-2598615a039a
The code gives syntax error while calling tf.compat.v1.estimator.Estimator() with model_fn as an argument:
NameError: name 'model_fn' is not defined
I tried defining model_fn as: model_fn = model_fn(hparams, tf.estimator.ModeKeys.TRAIN) but that did not help. I am not sure where model_fn should be defined.
Full code is here. Any help would be appreciated.
Tried different approaches but dont know how to define Model_fn
type here
import tensorflow as tf
import numpy as np
import os
import json
import random
import time
import argparse
# Define the command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, required=True,
help="Path to the dataset")
parser.add_argument("--model_path", type=str, required=True,
help="Path to the pre-trained model")
parser.add_argument("--output_path", type=str, required=True,
help="Path to save the fine-tuned model")
parser.add_argument("--batch_size", type=int, default=16,
help="Batch size for training")
parser.add_argument("--epochs", type=int, default=1,
help="Number of epochs to train for")
args = parser.parse_args()
# Load the pre-trained GPT-2 model
with open(os.path.join(args.model_path, "hparams.json"), "r") as f:
hparams = json.load(f)
model = tf.compat.v1.estimator.Estimator(
model_fn=model_fn, #<- error occurs here
model_dir=args.output_path,
params=hparams,
config=tf.compat.v1.estimator.RunConfig(
save_checkpoints_steps=5000,
keep_checkpoint_max=10,
save_summary_steps=5000
)
)
# Define the input function for the dataset
def input_fn(mode):
dataset = tf.data.TextLineDataset(args.dataset_path)
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(args.batch_size)
dataset = dataset.map(lambda x: tf.strings.substr(x, 0, hparams["n_ctx"]))
iterator = dataset.make_one_shot_iterator()
return iterator.get_next()
# Define the training function
def train():
for epoch in range(args.epochs):
model.train(input_fn=lambda: input_fn(tf.estimator.ModeKeys.TRAIN))
print(f"Epoch {epoch+1} completed.")
# Start the training
train()