Tags: pytorch, transformer-model, huggingface-transformers

How to avoid iterating over Dataloader while resuming training in Huggingface Trainer class?


I'm currently using Hugging Face's Trainer class to train DistilBERT for a regression problem with a custom loss function. I rely on its checkpoints to resume training because of the ephemeral nature of my compute and unexpected errors.

The issue I'm facing is that each time I resume training from a checkpoint (passed via the model_path argument of Trainer.train()), the class iterates over the dataloader until it reaches the iteration count saved in the checkpoint (see the lines in the Trainer class that correspond to this behaviour).

This might usually not be an issue, but because of my dataloader's collate function and the size of the dataset, iterating for that long without doing any training is expensive and slows down the overall run.
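For reference, this is roughly the setup in question (a minimal sketch: my_expensive_collate_fn, train_dataset and the checkpoint path are placeholders, and depending on the transformers version the resume argument is model_path or resume_from_checkpoint):

```python
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments

model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=1  # single output for regression
)

training_args = TrainingArguments(
    output_dir="./results",
    save_steps=500,  # checkpoints used to survive preemption / crashes
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,             # assumed defined; large dataset
    data_collator=my_expensive_collate_fn,   # assumed defined; the costly collate function
)

# On resume, the Trainer fast-forwards the dataloader (and therefore the
# collate function) up to the step stored in the checkpoint before any
# optimizer updates happen.
trainer.train(resume_from_checkpoint="./results/checkpoint-500")
```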

I planned on utilizing a custom sampler class, something along the lines of this, with a parameter to resume from a given index (a rough sketch of the idea is shown below), but that too seems like quite a hack for this problem.
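The sampler idea would look something like the following. This is purely an illustrative sketch of "resume from a saved index" (ResumableSampler is a hypothetical helper, not a drop-in replacement for the Trainer's own samplers, and it doesn't restore any shuffling order):

```python
from torch.utils.data import Sampler

class ResumableSampler(Sampler):
    """Yields dataset indices sequentially, starting from a saved position.

    Hypothetical helper to show the 'resume from an index' idea; the start
    index would have to be persisted alongside the checkpoint.
    """

    def __init__(self, data_source, start_index: int = 0):
        self.data_source = data_source
        self.start_index = start_index

    def __iter__(self):
        # Skip indices that were already consumed before the interruption.
        return iter(range(self.start_index, len(self.data_source)))

    def __len__(self):
        return len(self.data_source) - self.start_index
```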

What alternative could I try to avoid these wasted compute cycles?


Solution

  • Well, it looks like Hugging Face has provided a solution to this via the ignore_data_skip argument in TrainingArguments.

    You do have to be careful with this flag, though. Training will essentially start a new epoch from step 0, so the dataloader is not fast-forwarded and the data order won't match an uninterrupted run, but the optimizer and model state are restored to whatever they were at the resume point.
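A minimal sketch of the flag in use (the output directory and checkpoint path are placeholders, and model and train_dataset are assumed to be defined as in your existing setup):

```python
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    ignore_data_skip=True,   # resume without fast-forwarding the dataloader
    # ... the rest of your existing arguments ...
)

trainer = Trainer(
    model=model,                  # assumed defined, e.g. your DistilBERT regression model
    args=training_args,
    train_dataset=train_dataset,  # assumed defined
)

# Optimizer, scheduler and model weights are restored from the checkpoint,
# but the dataloader starts from the beginning of the epoch instead of
# skipping batches, so the data order differs from an uninterrupted run.
trainer.train(resume_from_checkpoint="./results/checkpoint-500")
```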