Tags: pytorch, classification, prediction, image-classification

IndexError: list index out of range in prediction of images


I am running predictions on images. I have written out all the class names, and the test folder contains 20 images. Can someone give me a hint as to why I am getting this error? How can we check the indices the model outputs?

Code

import numpy as np
import sys, random
import torch
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import glob

# Paths for image directory and model
IMDIR = './test'
MODEL = 'checkpoint/resnet18/Monday_31_May_2021_21h_25m_05s/resnet18-1000-regular.pth'

# Load the model for testing
model = models.resnet18()

model.named_children()

torch.save(model.state_dict, MODEL)
model.eval()

# Class labels for prediction
class_names = ['BC', 'BK', 'CC', 'CL', 'CM', 'DF', 'DG', 'DS', 'HL', 'IF', 'JD', 'JS', 'LD', 'LP', 'LS', 'PO', 'RI',
               'SD', 'SG', 'TO']


# Retrieve a random sample of images from the directory (sample size is 1 here; the 3x3 grid fits up to 9)
files = Path(IMDIR).resolve().glob('*.*')
print(files)

images = random.sample(list(files), 1)
print(images)
# Configure plots
fig = plt.figure(figsize=(9, 9))
rows, cols = 3, 3

# Preprocessing transformations
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    # transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(0.5306, 0.1348)
])

# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Perform prediction and plot results
with torch.no_grad():
    for num, img in enumerate(images):
        img = Image.open(img).convert('RGB')
        inputs = preprocess(img).unsqueeze(0).cpu()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        label = class_names[preds]
        plt.subplot(rows, cols, num + 1)
        plt.title("Pred: " + label)
        plt.axis('off')
        plt.imshow(img)
'''
Sample run: python test.py test
'''

Traceback

Traceback (most recent call last):
  File "/media/khawar/HDD_Khawar/CVPR/pytorch-cifar100/test_box.py", line 57, in <module>
    label = class_names[preds]
IndexError: list index out of range

Solution

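  • Your error stems from the fact that you never modify the final linear layer of your ResNet model. torchvision's models.resnet18() ends in a layer with 1000 outputs (one per ImageNet class), so torch.max(outputs, 1) can return any index from 0 to 999, while class_names only has 20 entries. Any prediction of 20 or higher raises the IndexError.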

    I suggest adding this code:

    from torch import nn

    # What you have
    model = models.resnet18()

    # What you need: one output node per class name
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, len(class_names)))

    This replaces the last linear layer so that it outputs the correct number of nodes: one per entry in class_names (20 here).

    Sarthak