Tags: python, pytorch, torch

How to load and use a pretrained PyTorch InceptionV3 model to classify an image


I have the same problem as "How can I load and use a PyTorch (.pth.tar) model", which has no accepted answer, and I can't figure out how to follow the advice given there.

I'm new to PyTorch. I am trying to load the pretrained PyTorch model referenced here: https://github.com/macaodha/inat_comp_2018

I'm pretty sure I am missing some glue.

# load the model
import torch
from PIL import Image
from torchvision import transforms

model = torch.load("iNat_2018_InceptionV3.pth.tar", map_location='cpu')

# try to get it to classify an image
imsize = 256
# transforms.Resize replaces the deprecated transforms.Scale
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """load image, returns a CPU tensor with a leading batch dimension"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = image.unsqueeze(0)  # add the batch dimension: (1, C, H, W)
    return image.cpu()  # assumes that you're using CPU

image = image_loader("test-image.jpg")
model.predict(image)

Produces the error:

----> 1 model.predict(image)

AttributeError: 'dict' object has no attribute 'predict'


Solution

  • Problem

    What you loaded isn't actually a model. The file is a checkpoint: when it was saved, the author stored not only the model's parameters but also other information about the training run, all collected in a dict.

    Therefore, torch.load("iNat_2018_InceptionV3.pth.tar") simply returns a dict, which of course has no attribute called predict.

    model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
    type(model)
    # dict
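
To see what the checkpoint holds before trying to use it, a small helper along these lines can summarize a loaded dict. This is only a sketch: summarize_checkpoint is a hypothetical name, and the 'epoch' key in the demo is an assumption about what a checkpoint might contain. Only 'state_dict' is confirmed by the code in this answer.

```python
def summarize_checkpoint(checkpoint):
    """Return {key: short description} for each entry of a loaded checkpoint dict."""
    summary = {}
    for key, value in checkpoint.items():
        if isinstance(value, dict):
            # A state_dict is an (ordered) dict mapping parameter names to tensors
            summary[key] = f"dict with {len(value)} entries"
        else:
            summary[key] = type(value).__name__
    return summary

# Stand-in checkpoint so the sketch runs without the real file (keys are assumed)
demo_checkpoint = {"epoch": 3, "state_dict": {"fc.weight": 0.0, "fc.bias": 0.0}}
print(summarize_checkpoint(demo_checkpoint))
```

With the real file, you would pass in the dict returned by torch.load, i.e. summarize_checkpoint(model).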
    

    Solution

    What you need to do first, in this case and in general, is instantiate the desired model class and then load the saved parameters into it, as per the official guide "Load models".

    # First try
    from torchvision.models import Inception3
    v3 = Inception3()
    v3.load_state_dict(model['state_dict']) # model that was imported in your code.
    

    However, passing model['state_dict'] in directly raises errors about mismatched shapes between the checkpoint and Inception3's parameters.
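
If the error message is hard to read, one way to pin down which parameters disagree is to compare shapes key by key. The helper below is a sketch with a hypothetical name (find_shape_mismatches). It only assumes that each value exposes a .shape attribute, as tensors do. The demo uses stand-in objects and illustrative shapes rather than the real checkpoint.

```python
from types import SimpleNamespace

def find_shape_mismatches(model_state, checkpoint_state):
    """List (name, model_shape, checkpoint_shape) for shared keys whose shapes differ."""
    mismatches = []
    for name, param in model_state.items():
        if name in checkpoint_state:
            expected = tuple(param.shape)
            found = tuple(checkpoint_state[name].shape)
            if expected != found:
                mismatches.append((name, expected, found))
    return mismatches

# Stand-ins with a .shape attribute, so the sketch runs without PyTorch installed
fresh = {
    "fc.weight": SimpleNamespace(shape=(1000, 2048)),
    "fc.bias": SimpleNamespace(shape=(1000,)),
    "Conv2d_1a_3x3.conv.weight": SimpleNamespace(shape=(32, 3, 3, 3)),
}
saved = {
    "fc.weight": SimpleNamespace(shape=(8142, 2048)),
    "fc.bias": SimpleNamespace(shape=(8142,)),
    "Conv2d_1a_3x3.conv.weight": SimpleNamespace(shape=(32, 3, 3, 3)),
}
for name, expected, found in find_shape_mismatches(fresh, saved):
    print(f"{name}: model has {expected}, checkpoint has {found}")
```

With real objects the call would be find_shape_mismatches(v3.state_dict(), model['state_dict']).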

    The shapes differ because the author modified Inception3 after instantiating it, so you need to know what was changed. Luckily, you can find that in the original author's train_inat.py.

    # What the author has done
    model = inception_v3(pretrained=True)
    model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
    model.aux_logits = False
    

    Now that we know what to change, let's modify our first try accordingly.

    # Second try
    import torch.nn as nn
    from torchvision.models import Inception3
    v3 = Inception3()
    v3.fc = nn.Linear(2048, 8142)  # match the checkpoint's 8142 output classes
    v3.aux_logits = False
    v3.load_state_dict(model['state_dict']) # model that was imported in your code.
    

    And there you go: the model is now loaded successfully!
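
To actually classify an image, the loaded model still has to be switched to evaluation mode and called on a batch. The following is a minimal sketch, not the original author's evaluation code: predict_topk is a hypothetical helper, and the tiny stand-in model and random input exist only so the block runs without the checkpoint.

```python
import torch

def predict_topk(model, batch, k=5):
    """Return (probabilities, class indices) of the k most likely classes per image."""
    model.eval()  # disable dropout and use running batch-norm statistics
    with torch.no_grad():  # no gradients are needed for inference
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
    return probs.topk(k, dim=1)

# Stand-in model and input (assumptions for the demo, not the real network)
demo_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
demo_batch = torch.rand(1, 3, 8, 8)
top_probs, top_classes = predict_topk(demo_model, demo_batch, k=5)
print(top_classes)
```

With the real model the call would be predict_topk(v3, image), where image comes from image_loader. Note that torchvision documents Inception v3 as expecting 299x299 inputs, and that the preprocessing (resize, crop, normalization) should match whatever the author used in training, which this sketch does not attempt to reproduce. The returned indices also still need to be mapped to the dataset's category names.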