I am working on Project 2 of a course with Udacity (Artificial Intelligence with Python Programming).
I have trained a model and saved it in checkpoint.pth, and I want to load checkpoint.pth so I can rebuild the model.
I have written the code to save checkpoint.pth and to load the checkpoint:
model.class_to_idx = image_datasets['train_dir'].class_to_idx
model.cpu()
checkpoint = {'input_size': 25088,
              'output_size': 102,
              'hidden_layers': 4096,
              'epochs': epochs,
              'optimizer': optimizer.state_dict(),
              'state_dict': model.state_dict(),
              'class_to_index': model.class_to_idx
              }
torch.save(checkpoint, 'checkpoint.pth')
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint.Network(checkpoint['input_size'],
                               checkpoint['output_size'],
                               checkpoint['hidden_layers'],
                               checkpoint['epochs'],
                               checkpoint['optimizer'],
                               checkpoint['class_to_index']
                               )
    model.load_state_dict(checkpoint['state_dict'])
    return model
model = load_checkpoint('checkpoint.pth')
While loading checkpoint.pth, I get an error:
AttributeError: 'dict' object has no attribute 'Network'
I want to load the checkpoint successfully.
Thank you
UPDATE: With the full code visible, I think the issue is in the implementation. torch.load deserializes the dict that was serialized to the file, so it returns the original dict object. Inside the function, checkpoint is therefore equal to the checkpoint dict you originally defined; it is not a model.
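In this instance, I think what you are actually looking to do is call the load on the file saved as checkpoint.pth, and the first call might not be necessary.
Note that this returns a model only if the whole model object was passed to torch.save; if a dict was saved, as in your code, it returns that dict.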
def load_checkpoint(filepath):
    model = torch.load(filepath)
    return model
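A quick way to check what torch.load hands back is to round-trip a small dict through an in-memory buffer. This is only a sketch: the dict below is a made-up stand-in with a dummy tensor in place of the real state_dict, not your actual checkpoint.

```python
import io

import torch

# Made-up stand-in for the saved checkpoint (a dummy tensor replaces the real weights)
checkpoint = {
    'input_size': 25088,
    'output_size': 102,
    'state_dict': {'weight': torch.zeros(2, 2)},
}

# Save to an in-memory buffer instead of a file, then read it back
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)
loaded = torch.load(buffer)

# The loaded object is a dict with the same keys, not a model
print(type(loaded).__name__)
print(sorted(loaded))
```

If the printed type is dict, the model still has to be rebuilt from the values stored in it.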
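The other possibility is that you only need the weights stored under 'state_dict'. Note that load_state_dict is a method of the model (nn.Module), not a function in torch, so it needs an already constructed model to load into. It would be just a small adjustment:
def load_checkpoint(filepath, model):
    checkpoint = torch.load(filepath)
    # Copy the saved weights into the existing model instance
    model.load_state_dict(checkpoint['state_dict'])
    return model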
The most likely problem is that you are accessing Network as an attribute of checkpoint, but the checkpoint dictionary object does not contain the Network class.
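The error can be reproduced without PyTorch at all, since it comes from plain dict behaviour. A minimal sketch, using a hand-written dict in place of the loaded checkpoint:

```python
# Hand-written stand-in for the dict that torch.load returns
checkpoint = {'input_size': 25088, 'output_size': 102, 'hidden_layers': 4096}

# Key access works: these are the values that were stored in the dict
input_size = checkpoint['input_size']

# Attribute access fails: a dict has no attribute called Network
try:
    checkpoint.Network
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(input_size)
print(error_message)
```

The class has to come from your own code (defined or imported in the script that loads the checkpoint), while the dict only supplies its arguments.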
I can't speak to the actual lesson or its nuances, but the simplest solution might be to call the Network class directly, passing the values already stored in the checkpoint dictionary, like so:
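def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    # Network must be defined or imported in this script;
    # the arguments must match its __init__ signature
    model = Network(checkpoint['input_size'],
                    checkpoint['output_size'],
                    checkpoint['hidden_layers'],
                    checkpoint['epochs'],
                    checkpoint['optimizer'],
                    checkpoint['class_to_index'])
    model.load_state_dict(checkpoint['state_dict'])
    return model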
The checkpoint dict only contains the values you saved into it ('input_size', 'output_size', etc.). This is just the most obvious issue I see.
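Putting it together, here is a minimal end-to-end sketch of saving and rebuilding. The Network class below is a hypothetical stand-in (one hidden layer, three constructor arguments), since the course's actual class and its signature are not shown; the small layer sizes and class names are made up to keep the demo fast.

```python
import os
import tempfile

import torch
from torch import nn


class Network(nn.Module):
    """Hypothetical stand-in for the course classifier; the real signature may differ."""

    def __init__(self, input_size, output_size, hidden_units):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, output_size),
        )

    def forward(self, x):
        return self.layers(x)


def save_checkpoint(model, filepath, input_size, output_size, hidden_units):
    # Store the hyperparameters needed to rebuild the model, plus its weights
    checkpoint = {
        'input_size': input_size,
        'output_size': output_size,
        'hidden_layers': hidden_units,
        'state_dict': model.state_dict(),
        'class_to_index': model.class_to_idx,
    }
    torch.save(checkpoint, filepath)


def load_checkpoint(filepath):
    # torch.load returns the saved dict; the model is rebuilt from its values
    checkpoint = torch.load(filepath, map_location='cpu')
    model = Network(checkpoint['input_size'],
                    checkpoint['output_size'],
                    checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_index']
    return model


# Made-up sizes and labels, only for demonstration
original = Network(8, 3, 4)
original.class_to_idx = {'daisy': 0, 'rose': 1, 'tulip': 2}

path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
save_checkpoint(original, path, 8, 3, 4)
restored = load_checkpoint(path)
print(type(restored).__name__)
```

In this sketch, epochs and the optimizer state are left out of the constructor call. If you need them to resume training, they can stay in the dict and be read back separately after the model is rebuilt.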