Tags: python, tensorflow, nlp, gpt-2

Set the number of training iterations for GPT-2 fine-tuning


I'm fine-tuning a GPT-2 model following this tutorial:

https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f

along with its associated GitHub repository:

https://github.com/nshepperd/gpt-2

I have been able to replicate the examples. My issue is that I can't find a parameter to set the number of iterations. The training script generates a sample every 100 iterations and saves a model checkpoint every 1000 iterations, but there seems to be no parameter to train for, say, 5000 iterations and then stop.

The script for training is here: https://github.com/nshepperd/gpt-2/blob/finetuning/train.py
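The periodic sampling and saving described above comes from modulo checks on a step counter (the loop quoted in the solution below tests counter % args.sample_every == 0 and counter % args.save_every == 0). A small sketch of that arithmetic, using the 100 and 1000 intervals from the question; events_for_step and count_events are hypothetical helpers, not part of train.py, and the counter starting at 1 on a fresh run is an assumption:

```python
def events_for_step(counter, sample_every=100, save_every=1000):
    """Return the periodic actions that fire at a given step.

    Hypothetical helper mirroring the `counter % N == 0` checks in the loop.
    """
    events = []
    if counter % save_every == 0:
        events.append('save')
    if counter % sample_every == 0:
        events.append('sample')
    return events


def count_events(total_steps, start=1, sample_every=100, save_every=1000):
    """Count how often each action fires over total_steps consecutive steps.

    Assumes the step counter starts at `start` (1 on a fresh run).
    """
    counts = {'save': 0, 'sample': 0}
    for counter in range(start, start + total_steps):
        for event in events_for_step(counter, sample_every, save_every):
            counts[event] += 1
    return counts


# 5000 steps with the intervals from the question.
print(count_events(5000))  # {'save': 5, 'sample': 50}
```

So a 5000-step run would produce 50 samples and 5 checkpoints, but nothing in these checks ever ends the loop: with while True, stopping still depends on interrupting the script.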

EDIT:

As suggested by cronoik, I'm trying to replace the while True loop with a for loop.

I made these changes:

  1. Adding one additional argument:

    parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000, help='a number representing how many training steps the model shall be trained for')

  2. Changing the loop:

     try:
         for iter_count in range(training_steps):
             if counter % args.save_every == 0:
                 save()
    
  3. Using the new argument:

    python3 train.py --training_steps 300

But I'm getting this error:

  File "train.py", line 259, in main
    for iter_count in range(training_steps):
NameError: name 'training_steps' is not defined
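The NameError is raised because argparse does not create local variables: parse_args() returns a Namespace object, and every option becomes an attribute on it. The value of --training_steps is therefore read as args.training_steps, in the same way the loop in step 2 already reads args.save_every. A minimal standalone sketch, assuming (as the traceback's "in main" and the existing args.save_every reference suggest) that args is in scope where the loop runs:

```python
import argparse

parser = argparse.ArgumentParser(description='Sketch of the training_steps option.')
parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000,
                    help='how many training steps to run before stopping')


def main(argv=None):
    # parse_args returns a Namespace; each option becomes an attribute on it.
    args = parser.parse_args(argv)
    steps_run = 0
    # range(training_steps) would raise NameError: no local of that name exists.
    for _ in range(args.training_steps):
        steps_run += 1  # stand-in for one training step
    return args, steps_run
```

Calling main(['--training_steps', '300']) here is equivalent to running python3 train.py --training_steps 300. In train.py itself, the only change needed should be for iter_count in range(args.training_steps):.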

Solution

  • All you have to do is replace the while True loop with a for loop:

    try:
        # Replaced `while True:` with a bounded loop so training
        # stops on its own after 5000 steps.
        for _ in range(5000):
            if counter % args.save_every == 0:
                save()
            if counter % args.sample_every == 0:
                generate_samples()
            if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                validation()
    
            if args.accumulate_gradients > 1:
                sess.run(opt_reset)
                for _ in range(args.accumulate_gradients):
                    sess.run(
                        opt_compute, feed_dict={context: sample_batch()})
                (v_loss, v_summary) = sess.run((opt_apply, summaries))
            else:
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summaries),
                    feed_dict={context: sample_batch()})
    
            summary_log.add_summary(v_summary, counter)
    
            avg_loss = (avg_loss[0] * 0.99 + v_loss,
                        avg_loss[1] * 0.99 + 1.0)
    
            print(
                '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1]))
    
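            counter += 1
        # The loop now ends normally, so save a final checkpoint;
        # otherwise steps since the last periodic save are lost.
        save()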
    except KeyboardInterrupt:
        print('interrupted')
        save()
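Combining this answer with the --training_steps argument from the edit gives a loop that stops on its own and still saves when interrupted with Ctrl-C. The sketch below is self-contained: parse_args lists only a hypothetical subset of train.py's options, and train_step, save and generate_samples are hypothetical callables standing in for the TensorFlow session calls and helper functions in train.py. The counter starting at 1 is an assumption about a fresh run.

```python
import argparse


def parse_args(argv=None):
    # Hypothetical subset of train.py's options, enough for this sketch.
    parser = argparse.ArgumentParser()
    parser.add_argument('--training_steps', type=int, default=1000)
    parser.add_argument('--save_every', type=int, default=1000)
    parser.add_argument('--sample_every', type=int, default=100)
    return parser.parse_args(argv)


def run_training(args, train_step, save, generate_samples, counter=1):
    """Run args.training_steps steps, saving on normal exit and on Ctrl-C.

    train_step, save and generate_samples are hypothetical callables that
    receive the current counter; they stand in for train.py's session calls
    and helper functions. Returns the counter value when the loop stops.
    """
    try:
        for _ in range(args.training_steps):
            if counter % args.save_every == 0:
                save(counter)
            if counter % args.sample_every == 0:
                generate_samples(counter)
            train_step(counter)
            counter += 1
        # Bounded loop finished: save once more so recent steps are kept.
        save(counter)
    except KeyboardInterrupt:
        print('interrupted')
        save(counter)
    return counter
```

For example, with --training_steps 250 --save_every 100 --sample_every 50 and a fresh counter of 1, save is called at 100, 200 and finally at 251, and samples are generated every 50 steps. Because range counts the steps taken in this run, both this sketch and the answer's loop train that many additional steps when resuming from a checkpoint, rather than stopping at a fixed global step number.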