I currently have my trainer set up as:
training_args = TrainingArguments(
    output_dir=f"./results_{model_checkpoint}",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    save_total_limit=1,
    resume_from_checkpoint=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_qa["train"],
    eval_dataset=tokenized_qa["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)
After training, my output_dir contains several files that the trainer saved:
['README.md',
'tokenizer.json',
'training_args.bin',
'.git',
'.gitignore',
'vocab.txt',
'config.json',
'checkpoint-5000',
'pytorch_model.bin',
'tokenizer_config.json',
'special_tokens_map.json',
'.gitattributes']
From the documentation, it seems that resume_from_checkpoint should continue training the model from the last checkpoint:
resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
But when I call trainer.train()
it seems to delete the last checkpoint and start a new one:
Saving model checkpoint to ./results_distilbert-base-uncased/checkpoint-500
...
Deleting older checkpoint [results_distilbert-base-uncased/checkpoint-5000] due to args.save_total_limit
Does it actually continue training from the last checkpoint (i.e., step 5000) and simply restart the checkpoint numbering from 0 (saving the first one after 500 steps as "checkpoint-500"), or does it not resume training at all? I haven't found a way to test this, and the documentation is not clear on that point.
Resuming does work, but not the way you have it set up: the resume_from_checkpoint field in TrainingArguments is not used by the Trainer itself, so calling trainer.train() with no arguments implicitly tells it to ignore existing checkpoints and start from scratch (which is why your old checkpoint gets deleted once save_total_limit kicks in). To actually resume, pass the argument to the train call itself: trainer.train(resume_from_checkpoint=True), or pass a string pointing to a specific checkpoint path.
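For example, a minimal sketch reusing the trainer and model_checkpoint defined above (the explicit path assumes the checkpoint-5000 directory from your listing still exists):

# Resume from the most recent checkpoint found in args.output_dir
trainer.train(resume_from_checkpoint=True)

# Or resume from a specific checkpoint directory
trainer.train(resume_from_checkpoint=f"./results_{model_checkpoint}/checkpoint-5000")

You can also check that training really resumed by inspecting trainer.state.global_step afterwards (or the trainer_state.json inside the newly saved checkpoint): when resuming, the step count continues from the loaded checkpoint rather than restarting at 0, so the new checkpoints are numbered accordingly.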