pytorch, pytorch-lightning, learning-rate

Getting rid of the clutter of `.lr_find_` checkpoints in PyTorch Lightning?


When using Lightning's built-in LR finder:

from lightning.pytorch.tuner import Tuner

# Create a Tuner
tuner = Tuner(trainer)

# finds learning rate automatically
# sets hparams.lr or hparams.learning_rate to that learning rate
tuner.lr_find(model)

a lot of checkpoints named .lr_find_XXX.ckpt are created in the working directory, which creates clutter. How can I make sure that these checkpoints are not created, or keep them in a dedicated directory?


Solution

  • As defined in lr_finder.py:

    # Save initial model, that is loaded after learning rate is found
    ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)
    

    the initial model is saved in the checkpoint you are mentioning, .lr_find_XXX.ckpt, in the directory trainer.default_root_dir. If no default directory is defined when the trainer is initialized, the current working directory is used as default_root_dir. After finding the ideal learning rate, lr_find restores the initial model from the checkpoint and removes the checkpoint:

    # Restore initial state of model
    trainer._checkpoint_connector.restore(ckpt_path)
    trainer.strategy.remove_checkpoint(ckpt_path)
    

    You are probably stopping the program before the checkpoint is restored and removed, so you have two options:

    1. Wait for the ideal learning rate to be found so that the checkpoint is removed
    2. Change the default_root_dir: Trainer(default_root_dir='./NAME_OF_THE_DIR'), but be aware that this is also the directory that the Lightning logs are saved to (see the sketch after this list).
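
    For what it's worth, here is a minimal, self-contained sketch of option 2, assuming the Lightning 2.x import paths; TinyModel and the "lr_find_runs" directory are illustrative names, not part of the original question. It points default_root_dir at a dedicated directory so the temporary checkpoint (and the logs) land there, and it shows how stray .lr_find_*.ckpt files from interrupted runs can be cleaned up afterwards.

    import glob
    import os

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.tuner import Tuner

    class TinyModel(LightningModule):
        """Minimal module exposing an `lr` attribute for lr_find to tune."""

        def __init__(self, lr: float = 1e-3):
            super().__init__()
            self.lr = lr
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.lr)

        def train_dataloader(self):
            # Enough batches for the default number of LR-finder steps.
            x, y = torch.randn(4096, 8), torch.randn(4096, 1)
            return DataLoader(TensorDataset(x, y), batch_size=32)

    # Point default_root_dir at a dedicated directory so the temporary
    # .lr_find_XXX.ckpt (and the Lightning logs) are written there instead
    # of the current working directory.
    trainer = Trainer(default_root_dir="./lr_find_runs", max_epochs=1)
    tuner = Tuner(trainer)
    tuner.lr_find(TinyModel())  # the checkpoint is removed when this returns

    # If an earlier run was interrupted before lr_find could clean up,
    # leftover checkpoints can be removed by hand:
    for ckpt in glob.glob(os.path.join("./lr_find_runs", ".lr_find_*.ckpt")):
        os.remove(ckpt)

    Either way, as long as lr_find is allowed to finish, the checkpoint is only a temporary file and disappears on its own.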