python, pytorch, pytorch-lightning, ckpt

PyTorch resume from checkpoint in PyTorch Lightning


I'm running a PyTorch model that has already been written. In that code the authors save checkpoint (.ckpt) files each epoch, but they didn't give any option to resume training from one of these checkpoints. Here is the authors' original code.

trainer = pl.Trainer.from_argparse_args(
    args,
    default_root_dir=args.logdir,
    gpus=args.gpus,
    accelerator='ddp',
    sync_batchnorm=True,
    plugins=DDPPlugin(find_unused_parameters=False),
    profiler='simple',
    benchmark=True,
    log_every_n_steps=1,
    flush_logs_every_n_steps=5,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=args.val_every,
    max_epochs=args.epochs,
    logger=logger,
)

So I changed the above code to start training from a given checkpoint, selected via a command-line argument. Here is what I have done.

if args.loadfromcheckpoint > 0:
    trainer = pl.Trainer(
        resume_from_checkpoint=args.logdir + "/epoch={checkpoint}-last.ckpt".format(
            checkpoint=args.loadfromcheckpoint),
        default_root_dir=args.logdir,
        gpus=args.gpus,
        accelerator='ddp',
        sync_batchnorm=True,
        plugins=DDPPlugin(find_unused_parameters=False),
        profiler='simple',
        benchmark=True,
        log_every_n_steps=1,
        flush_logs_every_n_steps=5,
        callbacks=[checkpoint_callback],
        check_val_every_n_epoch=args.val_every,
        max_epochs=args.epochs,
        logger=logger,
    )
    trainer.fit(TCP_model, dataloader_train, dataloader_val)
else:
    trainer.fit(TCP_model, dataloader_train, dataloader_val)

The above code works fine. Since I'm quite new to PyTorch and PyTorch Lightning, I have the following questions:

  1. Does the Lightning API restore only the model's state_dict, or does it also restore everything else, such as optimizer_states and lr_schedulers?
  2. If Lightning doesn't load all of those, how do I load those states manually?

Solution

  • The Lightning API will load everything: with resume_from_checkpoint it restores the entire training state at that epoch, including the model's state_dict and the optimizer's and scheduler's state_dicts. If you only want to do a quick evaluation using the model's state_dict, use load_from_checkpoint instead.
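
To make both points concrete, here is a minimal sketch of the two ways to load a checkpoint without resume_from_checkpoint. The path and the TCPModel class name are placeholders for your own, and the dictionary keys listed are the ones Lightning normally writes into a .ckpt file; verify them on your version with checkpoint.keys().

import torch

# Hypothetical path; adapt to your logdir and epoch naming.
ckpt_path = "logdir/epoch=10-last.ckpt"

# Option 1: weights only (quick evaluation).
# load_from_checkpoint rebuilds the LightningModule and restores its
# state_dict and hyperparameters, but no optimizer/scheduler/trainer state.
# model = TCPModel.load_from_checkpoint(ckpt_path)

# Option 2: inspect / restore states manually.
# A Lightning .ckpt is an ordinary torch.load-able dictionary.
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint.keys())
# Typically includes: 'epoch', 'global_step', 'state_dict',
# 'optimizer_states', 'lr_schedulers', ...

# If you manage the optimizer/scheduler yourself rather than resuming
# through the Trainer, you can load their states by hand:
# model.load_state_dict(checkpoint["state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_states"][0])
# scheduler.load_state_dict(checkpoint["lr_schedulers"][0])

Note also that in more recent Lightning releases the resume_from_checkpoint argument of the Trainer has been deprecated; resuming is done by passing the path to trainer.fit instead, e.g. trainer.fit(TCP_model, dataloader_train, dataloader_val, ckpt_path=...).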