Search code examples
python pytorch conv-neural-network torch

How do I predict using a PyTorch model?


I created a PyTorch model to classify images. I saved it twice, once via its state_dict and once as the entire model, like this:

torch.save(model.state_dict(), "model1_statedict")
torch.save(model, "model1_complete")

How can I use these saved models? I'd like to run them on some images to see whether they perform well.

I am loading the model with:

model = torch.load(path_model)
model.eval()

This works all right, but I have no idea how to use the loaded model to predict on a new picture.


Solution

  • def predict(self, test_images):
        self.eval()
        # model is self(VGG class's object)
        
        count = test_images.shape[0]
        result_np = []
            
        for idx in range(0, count):
            # print(idx)
            img = test_images[idx, :, :, :]
            img = np.expand_dims(img, axis=0)
            img = torch.Tensor(img).permute(0, 3, 1, 2).to(device)
            # print(img.shape)
            pred = self(img)
            pred_np = pred.cpu().detach().numpy()
            for elem in pred_np:
                result_np.append(elem)
        return result_np
    

    The network is VGG-19; please refer to my source code.

    The overall architecture looks like this:

    # Architecture outline only: every `...` marks code elided by the author,
    # so this sketch is not runnable as written.
    # NOTE(review): evaluate() is called with model=self and invokes
    # model.eval(); a plain `object` subclass has no eval(), so the real class
    # presumably derives from torch.nn.Module -- confirm.
    class VGG(object):
        # Constructor body is elided in the original.
        def __init__(self):
        ...
    
    
        # Wraps the two image sets in loaders, sets up the optimizer and loss,
        # then runs the epoch loop, calling evaluate() on the validation loader
        # inside each epoch.
        def train(self, train_images, valid_images):
            # NOTE(review): torch.utils.data.Dataset is the abstract base class;
            # passing data to it directly looks like shorthand for a concrete
            # dataset (e.g. a custom subclass) -- confirm.
            train_dataset = torch.utils.data.Dataset(train_images)
            valid_dataset = torch.utils.data.Dataset(valid_images)
    
            trainloader = torch.utils.data.DataLoader(train_dataset)
            validloader = torch.utils.data.DataLoader(valid_dataset)
    
            # Arguments are elided; presumably torch.optim.Adam and
            # torch.nn.CrossEntropyLoss -- the imports are not shown.
            self.optimizer = Adam(...)
            self.criterion = CrossEntropyLoss(...)
        
            # `epochs` is not defined anywhere in this outline.
            for epoch in range(0, epochs):
                ...
                self.evaluate(validloader, model=self, criterion=self.criterion)
        ...
    
        # Puts `model` into eval mode and iterates over `dataloader`; the loop
        # body is elided in the original.
        def evaluate(self, dataloader, model, criterion):
            model.eval()
            for i, sample in enumerate(dataloader):
        ...
    
        # Body is elided here in the original outline.
        def predict(self, test_images):
        
        ...
    
    # Usage sketch: build the network, train it on the train/validation split,
    # then run predict() on a separate test set.
    if __name__ == "__main__":
        network = VGG()
        trainset, validset = get_dataset()    # placeholder helper, shown for illustration only
        # get_test_dataset() is likewise not defined in this outline.
        testset = get_test_dataset()
        
        network.train(trainset, validset)
    
        result = network.predict(testset)