
I'm fine-tuning a GPT-2 model following this tutorial:

https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f

and its associated GitHub repository:

https://github.com/nshepperd/gpt-2

I have been able to replicate the examples. My issue is that I can't find a parameter to set the number of iterations. The training script prints a sample every 100 iterations and saves a model checkpoint every 1000 iterations, but I can't find a parameter to train it for, say, 5000 iterations and then stop.

The script for training is here: https://github.com/nshepperd/gpt-2/blob/finetuning/train.py

EDIT:

As suggested by cronoik, I'm trying to replace the `while True` loop with a `for` loop.

I'm adding these changes:

  1. Adding a new argument:

    parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000, help='a number representing how many training steps the model shall be trained for')

  2. Changing the loop:

     try:
         for iter_count in range(training_steps):
             if counter % args.save_every == 0:
                 save()
    
  3. Using the new argument:

    python3 train.py --training_steps 300

But I'm getting this error:

  File "train.py", line 259, in main
    for iter_count in range(training_steps):
NameError: name 'training_steps' is not defined
Guy Coder
Luis Ramon Ramirez Rodriguez
    It should be `for iter_count in range(args.training_steps)`, not `for iter_count in range(training_steps)`, because the new argument is stored as an attribute of `args`, not as a standalone variable. – cronoik Sep 07 '19 at 20:26
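The `NameError` comes from how `argparse` works. `parse_args()` returns a `Namespace` object, and each option becomes an attribute on it (with the leading `--` dropped from the flag name), so no standalone variable called `training_steps` is ever created. The minimal sketch below uses a cut-down, hypothetical parser that contains only the new option, not the full parser from `train.py`:

```python
import argparse


def parse_training_args(argv=None):
    """Parse only the new option (hypothetical cut-down stand-in for train.py's parser)."""
    parser = argparse.ArgumentParser(description='GPT-2 fine-tuning (sketch)')
    parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000,
                        help='how many training steps to run before stopping')
    return parser.parse_args(argv)


args = parse_training_args(['--training_steps', '300'])
print(args)                            # Namespace(training_steps=300)
print(args.training_steps)             # 300
print('training_steps' in globals())   # False: parse_args() creates no such variable

# Correct form of the loop header: read the value from the Namespace.
for iter_count in range(args.training_steps):
    pass  # one training step would go here
print(iter_count)                      # 299
```

The same rule applies to the options `train.py` already has, which is why its loop reads `args.save_every` and `args.sample_every` rather than bare names.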

1 Answer


All you have to do is change the `while True` loop into a `for` loop:

try:
    # replaced:
    # while True:
    for _ in range(5000):  # run a fixed number of steps instead of looping forever
        if counter % args.save_every == 0:
            save()
        if counter % args.sample_every == 0:
            generate_samples()
        if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
            validation()

        if args.accumulate_gradients > 1:
            sess.run(opt_reset)
            for _ in range(args.accumulate_gradients):
                sess.run(
                    opt_compute, feed_dict={context: sample_batch()})
            (v_loss, v_summary) = sess.run((opt_apply, summaries))
        else:
            (_, v_loss, v_summary) = sess.run(
                (opt_apply, loss, summaries),
                feed_dict={context: sample_batch()})

        summary_log.add_summary(v_summary, counter)

        avg_loss = (avg_loss[0] * 0.99 + v_loss,
                    avg_loss[1] * 0.99 + 1.0)

        print(
            '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_loss,
                avg=avg_loss[0] / avg_loss[1]))

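        counter += 1

    # The loop now ends on its own, so save once more: the periodic save()
    # above runs before each training step, so without this the last
    # steps since the previous checkpoint would be lost.
    save()
except KeyboardInterrupt: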
    print('interrupted')
    save()
cronoik
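Without TensorFlow, the change in the answer is purely about control flow. The sketch below reproduces it with hypothetical stand-ins: `train_step()` replaces the `sess.run(...)` call, and `save()` / `generate_samples()` only record the counter value so the schedule can be inspected. It is an illustration under those assumptions, not code from `train.py`. It also saves once after the loop finishes normally, because the periodic save runs at the top of an iteration, before that iteration's training step.

```python
def run_training(steps, save_every=1000, sample_every=100, start_counter=1):
    """Run exactly `steps` iterations, then stop (sketch of the bounded loop).

    save(), generate_samples() and train_step() are hypothetical stand-ins that
    only record or compute values, so the schedule can be checked directly.
    """
    saved, sampled, losses = [], [], []
    counter = start_counter  # train.py may restore this from a checkpoint

    def save():
        saved.append(counter)

    def generate_samples():
        sampled.append(counter)

    def train_step():
        return 1.0 / counter  # placeholder "loss"

    try:
        for _ in range(steps):  # replaces `while True:`
            if counter % save_every == 0:
                save()
            if counter % sample_every == 0:
                generate_samples()
            losses.append(train_step())
            counter += 1
        # Normal exit: the last periodic save() happened before the final
        # training steps, so write one more checkpoint here.
        save()
    except KeyboardInterrupt:
        print('interrupted')
        save()
    return saved, sampled, len(losses)


saved, sampled, n = run_training(5000)
print(n)      # 5000
print(saved)  # [1000, 2000, 3000, 4000, 5000, 5001]
```

Because `range(steps)` counts iterations of the current run, a run that resumes from a checkpoint trains for `steps` *more* iterations on top of the restored counter. If you would rather stop at a fixed global step, a check such as `if counter > args.training_steps: break` inside the original `while True` loop would do that instead. That check is also only a sketch and is not an option `train.py` provides.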