
Hey, I’m trying to fine-tune Llama 2 and I can’t see where the checkpoints are getting saved. I’m using the following code:

import os

from transformers import TrainingArguments, TrainerCallback
from trl import SFTTrainer

output_dir = "./Llama-2-7b-hf-qlora"

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=5,
    max_steps=400,
    evaluation_strategy="steps", # Evaluate the model every logging step
    logging_dir="./logs",        # Directory for storing logs
    save_strategy="steps",       # Save the model checkpoint every logging step
    eval_steps=5,               # Evaluate and save checkpoints every 10 steps
    do_eval=True                 # Perform evaluation at the end of training
)

# Saves only the PEFT adapter weights at each checkpoint
class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        # Remove the full model weights; only the adapter is needed
        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


callbacks = [PeftSavingCallback()]
max_seq_length = 512
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_args,
    callbacks=callbacks
)
trainer.train()

(I added the callback based on the Supervised Fine-tuning Trainer guide.) How do I get a checkpoint saved every 5 or 10 steps?
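From the TrainingArguments docs it looks like save_strategy="steps" checkpoints every save_steps, which defaults to 500 (past my max_steps=400), so nothing ever gets saved during training. Am I right that I just need to set save_steps explicitly? Something like this (untested guess):

# Untested guess: add save_steps so checkpoints land every 5 steps
# instead of the 500-step default
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=5,
    max_steps=400,
    evaluation_strategy="steps",
    logging_dir="./logs",
    save_strategy="steps",
    save_steps=5,                # save a checkpoint every 5 steps
    eval_steps=5,
    do_eval=True
)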

johnny