python pytorch pytorch-lightning checkpointing

Get paths of saved checkpoints from PyTorch Lightning ModelCheckpoint


I am using PyTorch Lightning and, among other callbacks, a ModelCheckpoint that saves models with a formatted filename such as filename="model_{epoch}-{val_acc:.2f}".

I then want to load these checkpoints again; for simplicity, I want the best ones from save_top_k=N. Because the filename is dynamic, I wonder how I can retrieve the checkpoint files easily.
Is there a built-in attribute on the ModelCheckpoint or the trainer that gives me the saved checkpoints? For example, something like:

checkpoint_callback.get_top_k_paths()

I know I can do this with glob and model_dir, but I am wondering whether there is a built-in one-line solution somewhere.
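The glob approach mentioned in the question could look roughly like the sketch below. It assumes the checkpoints end in .ckpt and that Lightning's default auto_insert_metric_name=True is in effect, so filename="model_{epoch}-{val_acc:.2f}" produces names such as model_epoch=3-val_acc=0.87.ckpt. glob_checkpoints is a hypothetical helper, and model_dir is whatever directory the callback writes to.

```python
import re
from pathlib import Path


def glob_checkpoints(model_dir, metric="val_acc", top_k=None, mode="max"):
    """Hypothetical helper: list checkpoints in model_dir, best first.

    Assumes names like "model_epoch=3-val_acc=0.87.ckpt", i.e. the metric
    name followed by "=" and its value (auto_insert_metric_name=True).
    """
    pattern = re.compile(rf"{re.escape(metric)}=(\d+(?:\.\d+)?)")
    scored = []
    for path in Path(model_dir).glob("*.ckpt"):
        match = pattern.search(path.stem)
        if match:  # skip files such as last.ckpt that carry no metric value
            scored.append((float(match.group(1)), str(path)))
    # Highest value first for mode="max", lowest first for mode="min"
    scored.sort(key=lambda item: item[0], reverse=(mode == "max"))
    paths = [p for _, p in scored]
    return paths if top_k is None else paths[:top_k]
```

Parsing the value back out of the filename is fragile (it breaks if the filename template changes), which is why a built-in attribute is preferable.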


Solution

  • The checkpoints kept by the callback (the top save_top_k) are stored in ModelCheckpoint.best_k_models: Dict[str, Tensor], where the keys are the checkpoint paths and the values are the tracked (monitored) metric.

    ModelCheckpoint additionally has these attributes: best_model_path, best_model_score, kth_best_model_path, kth_value, last_model_path, and best_k_models.


    Note on loading a checkpoint:

    These values are only guaranteed to be restored when model_checkpoint.dirpath matches the one in checkpoints_state_dict["dirpath"], i.e. you did not change the directory; otherwise only best_model_path is restored.

    In that case, as Aniket Maurya states, you have to look at dirpath or at the files next to best_model_path in the same directory.
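Given best_k_models, getting the top-k paths in order is a single sort. top_k_paths below is a hypothetical helper for illustration; it assumes the values are plain numbers or 0-dim tensors (float() accepts both) and that mode is the same "max"/"min" the ModelCheckpoint was configured with.

```python
from typing import Any, List, Mapping


def top_k_paths(best_k_models: Mapping[str, Any], mode: str = "max") -> List[str]:
    """Hypothetical helper: return checkpoint paths from a best_k_models-style dict, best first.

    best_k_models maps checkpoint path -> monitored value.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
    # Iterating a dict yields its keys (the paths); sort them by their metric value
    return sorted(
        best_k_models,
        key=lambda path: float(best_k_models[path]),
        reverse=(mode == "max"),
    )
```

With a real callback this might be called as top_k_paths(checkpoint_callback.best_k_models, mode=checkpoint_callback.mode) after trainer.fit(...), and each returned path can then be passed to your LightningModule subclass's load_from_checkpoint.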