Tags: pytorch, huggingface-transformers, huggingface

Does Huggingface's "resume_from_checkpoint" work?


I currently have my trainer set up as:

training_args = TrainingArguments(
    output_dir=f"./results_{model_checkpoint}",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=True,
    save_total_limit=1,
    resume_from_checkpoint=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_qa["train"],
    eval_dataset=tokenized_qa["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics
)

After training, my output_dir contains several files saved by the trainer:

['README.md',
 'tokenizer.json',
 'training_args.bin',
 '.git',
 '.gitignore',
 'vocab.txt',
 'config.json',
 'checkpoint-5000',
 'pytorch_model.bin',
 'tokenizer_config.json',
 'special_tokens_map.json',
 '.gitattributes']

From the documentation of Trainer.train(), it seems that resume_from_checkpoint will continue training the model from the last checkpoint:

resume_from_checkpoint (str or bool, optional) — If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.
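The two forms described in that docstring can be sketched as a small resolver. This is a hypothetical helper (`resolve_resume_checkpoint` is not a transformers function) that assumes checkpoints are sub-folders named `checkpoint-<global_step>`, as with the `checkpoint-5000` folder above. For the `True` case, transformers itself provides `get_last_checkpoint` in `transformers.trainer_utils`; the sketch only approximates that behavior.

```python
import os
import re
from typing import Optional, Union

# Assumed naming scheme: <output_dir>/checkpoint-<global_step>
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def resolve_resume_checkpoint(
    resume_from_checkpoint: Union[str, bool, None], output_dir: str
) -> Optional[str]:
    """Hypothetical sketch of how the str/bool argument is interpreted."""
    if isinstance(resume_from_checkpoint, str):
        # A string is taken as an explicit path to one checkpoint folder.
        return resume_from_checkpoint
    if not resume_from_checkpoint:
        # None or False: nothing to resume, training starts from step 0.
        return None
    # True: pick the checkpoint folder with the highest step number.
    candidates = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        raise ValueError(f"No checkpoint found in {output_dir}")
    # Compare by the numeric step, not by the folder name as a string.
    return os.path.join(output_dir, max(candidates)[1])
```

The numeric comparison matters: a plain string comparison would rank `checkpoint-900` above `checkpoint-5000`.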

But when I call trainer.train() it seems to delete the last checkpoint and start a new one:

Saving model checkpoint to ./results_distilbert-base-uncased/checkpoint-500
...
Deleting older checkpoint [results_distilbert-base-uncased/checkpoint-5000] due to args.save_total_limit

Does it really continue training from the last checkpoint (i.e., step 5000) and just start counting new checkpoints from 0 (saving the first one after 500 steps as "checkpoint-500"), or does it not continue training at all? I haven't found a way to test this, and the documentation is not clear on it.
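One way to test it is through the step counter. As far as I know, the Trainer names each checkpoint folder after the global step, and a resumed run restores that step from the checkpoint's `trainer_state.json`. With the default `save_steps=500`, a fresh run therefore saves `checkpoint-500` first, while a run resumed from step 5000 would save `checkpoint-5500`. The helpers below are hypothetical and assume that naming scheme and a `global_step` key in `trainer_state.json`.

```python
import json
import os


def expected_first_checkpoint(start_step: int, save_steps: int = 500) -> str:
    """Name of the first checkpoint saved after start_step (assumes step-based saving)."""
    next_save = (start_step // save_steps + 1) * save_steps
    return f"checkpoint-{next_save}"


def saved_global_step(checkpoint_dir: str) -> int:
    """Read the step counter stored in <checkpoint_dir>/trainer_state.json."""
    with open(os.path.join(checkpoint_dir, "trainer_state.json")) as f:
        return json.load(f)["global_step"]


def looks_resumed(new_checkpoint_name: str, previous_step: int) -> bool:
    """True if the new checkpoint's step continues past previous_step."""
    return int(new_checkpoint_name.rsplit("-", 1)[1]) > previous_step
```

Applied to the log above, `looks_resumed("checkpoint-500", 5000)` returns `False`, which suggests the second run started again from step 0.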


Solution

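  • Yes, it works! Calling trainer.train() with no argument starts training from scratch: the resume_from_checkpoint field of TrainingArguments is not read by Trainer itself (it is intended for your own training scripts), so setting it there has no effect. That is why the new run saved checkpoint-500 and then deleted checkpoint-5000 because of save_total_limit=1. To resume, call trainer.train(resume_from_checkpoint=True) to load the last checkpoint in output_dir, or pass a string with the path to a specific checkpoint folder.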
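A common pattern (the transformers example scripts do something similar with `get_last_checkpoint`) is to resume only when a checkpoint actually exists, because `trainer.train(resume_from_checkpoint=True)` raises an error when `output_dir` has no checkpoint. The sketch below is hypothetical: `train_with_optional_resume` and `has_checkpoint` are not library functions, checkpoints are detected by the assumed `checkpoint-<step>` folder name, and the only Trainer API it relies on is `train(resume_from_checkpoint=...)`.

```python
import os
import re

# Assumed naming scheme: <output_dir>/checkpoint-<global_step>
_CHECKPOINT_RE = re.compile(r"^checkpoint-\d+$")


def has_checkpoint(output_dir: str) -> bool:
    """True if output_dir contains at least one checkpoint-<step> sub-folder."""
    if not os.path.isdir(output_dir):
        return False
    return any(
        _CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
        for name in os.listdir(output_dir)
    )


def train_with_optional_resume(trainer, output_dir: str):
    """Resume from the newest checkpoint if one exists, otherwise start fresh."""
    # True lets Trainer locate the last checkpoint in its output_dir;
    # None starts training from step 0.
    resume = True if has_checkpoint(output_dir) else None
    return trainer.train(resume_from_checkpoint=resume)
```

With the setup in the question, this would be called as `train_with_optional_resume(trainer, training_args.output_dir)`.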