
TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'


Please help me with the following. I want to resume training from a checkpoint, so I run

python main.py --config cfgs/cifar10.yaml --resume checkpoint/cifar10/ckpt.pth.tar

in the console, but it fails with this error:

Traceback (most recent call last):
  File "main.py", line 287, in <module>
    main()
  File "main.py", line 85, in main
    start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume)
TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'

Here is the relevant part of my code.

# argparser
parser = argparse.ArgumentParser(description='PyTorch Implementation of DCCM')
parser.add_argument('--resume', default=None, type=str, help='resume from a checkpoint')
parser.add_argument('--config', default='cfgs/config.yaml/', help='set configuration file')
parser.add_argument('--small_bs', default=32, type=int)
parser.add_argument('--input_size', default=96, type=int)
parser.add_argument('--split', default=None, type=int, help='divide the large forward batch to avoid OOM')

Resume training:

if args.resume:
    logger.info("=> loading checkpoint '{}'".format(args.resume))
    start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume) #line85  

Save and load checkpoint:

def load_checkpoint(model, dim_loss, classifier, optimizer, ckpt_path):

    checkpoint = torch.load(ckpt_path)
    
    model.load_state_dict(checkpoint['model'])
    dim_loss.load_state_dict(checkpoint['dim_loss'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    best_nmi = checkpoint['best_nmi']
    start_epoch = checkpoint['epoch']

    return start_epoch, best_nmi


def save_checkpoint(state, is_best_nmi, filename):
    torch.save(state, filename+'.pth.tar')
    if is_best_nmi:
        shutil.copyfile(filename+'.pth.tar', filename+'_best_nmi.pth.tar')

Thank you.


Solution

  • Your definition of load_checkpoint takes 5 parameters, but at line 85 you pass only 4, and since none of the parameters has a default value, all of them are required. Because you pass the arguments positionally, they bind left to right: optimizer lands in the classifier slot, args.resume lands in optimizer, and ckpt_path is left unfilled. You have two options: remove the classifier parameter from the function definition (it is never used in the body), or give it a value, either a default or at the call site.
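A minimal, PyTorch-free sketch of how the positional binding goes wrong, and of the first option (dropping the unused classifier parameter). The string arguments are just placeholders standing in for the real model, loss, and optimizer objects:

```python
# 5-parameter definition, as in the question.
def load_checkpoint(model, dim_loss, classifier, optimizer, ckpt_path):
    return ckpt_path

# Calling with 4 positional arguments binds them left to right:
# "optimizer" lands in the classifier slot, the path lands in optimizer,
# and ckpt_path is left unfilled, which raises the TypeError seen above.
try:
    load_checkpoint("model", "dim_loss", "optimizer", "ckpt.pth.tar")
except TypeError as e:
    print(e)  # load_checkpoint() missing 1 required positional argument: 'ckpt_path'

# Option 1: drop the unused classifier parameter so the definition
# matches the 4-argument call at line 85.
def load_checkpoint_fixed(model, dim_loss, optimizer, ckpt_path):
    return ckpt_path

print(load_checkpoint_fixed("model", "dim_loss", "optimizer", "ckpt.pth.tar"))
# -> ckpt.pth.tar
```

The alternative (option 2) is to keep the parameter but give it a default, e.g. `def load_checkpoint(model, dim_loss, optimizer, ckpt_path, classifier=None):`, so the existing 4-argument call keeps working.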