I'm fine-tuning a GPT-2 model following this tutorial:
With its associated GitHub repository:
https://github.com/nshepperd/gpt-2
I have been able to replicate the examples. My issue is that I can't find a parameter to set the number of iterations: the training script prints a sample every 100 iterations and saves a checkpoint every 1000 iterations, but I see no way to train for, say, 5000 iterations and then stop.
The script for training is here: https://github.com/nshepperd/gpt-2/blob/finetuning/train.py
EDIT:
As suggested by cronoik, I'm trying to replace the while loop with a for loop.
I'm making these changes:
Adding one additional argument:
parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000, help='a number representing how many training steps the model shall be trained for')
Changing the loop:
try:
    for iter_count in range(training_steps):
        if counter % args.save_every == 0:
            save()
Using the new argument:
python3 train.py --training_steps 300
But I'm getting this error:
File "train.py", line 259, in main
for iter_count in range(training_steps):
NameError: name 'training_steps' is not defined
All you have to do is to modify the while True loop to a for loop:
try:
    # replaced
    # while True:
    for i in range(5000):
        if counter % args.save_every == 0:
            save()
        if counter % args.sample_every == 0:
            generate_samples()
        if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
            validation()

        if args.accumulate_gradients > 1:
            sess.run(opt_reset)
            for _ in range(args.accumulate_gradients):
                sess.run(
                    opt_compute, feed_dict={context: sample_batch()})
            (v_loss, v_summary) = sess.run((opt_apply, summaries))
        else:
            (_, v_loss, v_summary) = sess.run(
                (opt_apply, loss, summaries),
                feed_dict={context: sample_batch()})

        summary_log.add_summary(v_summary, counter)

        avg_loss = (avg_loss[0] * 0.99 + v_loss,
                    avg_loss[1] * 0.99 + 1.0)

        print(
            '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
            .format(
                counter=counter,
                time=time.time() - start_time,
                loss=v_loss,
                avg=avg_loss[0] / avg_loss[1]))

        counter += 1
except KeyboardInterrupt:
    print('interrupted')
    save()
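If you want the step count to come from the --training_steps flag added in the question, keep in mind that argparse exposes parsed values on the args namespace, so the loop has to read args.training_steps; using the bare name training_steps is what raises the NameError shown above. Below is a minimal, self-contained sketch of that pattern (not the actual train.py); training_step() is a hypothetical stand-in for the real session/optimizer calls:

import argparse
import time

def training_step(step):
    # hypothetical stand-in for the real sess.run(...) training update
    time.sleep(0.01)
    return 1.0 / (step + 1)  # dummy loss value

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--training_steps', metavar='STEPS', type=int, default=1000,
                        help='how many training steps to run before stopping')
    parser.add_argument('--save_every', metavar='N', type=int, default=1000,
                        help='save a checkpoint every N steps')
    args = parser.parse_args()

    counter = 1
    try:
        # args.training_steps, not the bare name training_steps
        for _ in range(args.training_steps):
            if counter % args.save_every == 0:
                print('saving checkpoint at step', counter)
            loss = training_step(counter)
            print('[{}] loss={:2.2f}'.format(counter, loss))
            counter += 1
    except KeyboardInterrupt:
        print('interrupted')

if __name__ == '__main__':
    main()

Applied to train.py itself, the same change amounts to replacing range(5000) in the loop above with range(args.training_steps).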