I'm running a PyTorch model that has already been written. In that code the authors save checkpoint (.ckpt) files each epoch, but they didn't give any option to resume training from one of those checkpoints. Here is the authors' original code.
trainer = pl.Trainer.from_argparse_args(
    args,
    default_root_dir=args.logdir,
    gpus=args.gpus,
    accelerator='ddp',
    sync_batchnorm=True,
    plugins=DDPPlugin(find_unused_parameters=False),
    profiler='simple',
    benchmark=True,
    log_every_n_steps=1,
    flush_logs_every_n_steps=5,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=args.val_every,
    max_epochs=args.epochs,
    logger=logger,
)
So I changed the code above to resume training from a given checkpoint, selected via a command-line argument. Here is what I have done.
if args.loadfromcheckpoint > 0:
    trainer = pl.Trainer(
        resume_from_checkpoint=args.logdir + "/epoch={checkpoint}-last.ckpt".format(checkpoint=args.loadfromcheckpoint),
        default_root_dir=args.logdir,
        gpus=args.gpus,
        accelerator='ddp',
        sync_batchnorm=True,
        plugins=DDPPlugin(find_unused_parameters=False),
        profiler='simple',
        benchmark=True,
        log_every_n_steps=1,
        flush_logs_every_n_steps=5,
        callbacks=[checkpoint_callback],
        check_val_every_n_epoch=args.val_every,
        max_epochs=args.epochs,
        logger=logger,
    )
    trainer.fit(TCP_model, dataloader_train, dataloader_val)
else:
    trainer.fit(TCP_model, dataloader_train, dataloader_val)
The above code works fine. Since I'm quite new to PyTorch and PyTorch Lightning, I have the following questions.
If you use resume_from_checkpoint, the Lightning API will restore the entire training state at that epoch: the model's state_dict as well as the optimizer's and scheduler's state_dicts. If you just want to do a quick evaluation using only the model's state_dict, use load_from_checkpoint.
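For the second case, a minimal sketch might look like the following. Here TCP stands in for the authors' LightningModule class, the checkpoint filename is a placeholder, and it assumes a Lightning version recent enough to have Trainer.validate.

# Sketch only: TCP is a placeholder for the authors' LightningModule class,
# and the checkpoint filename is illustrative.
model = TCP.load_from_checkpoint(args.logdir + "/epoch=10-last.ckpt")
model.eval()

# Runs validation with just the restored weights; no optimizer or scheduler
# state is recovered, so this is for evaluation, not for resuming training.
eval_trainer = pl.Trainer(gpus=args.gpus, logger=False)
eval_trainer.validate(model, dataloader_val)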