Tags: python-3.x, pytorch, predict, pytorch-dataloader, imageloader

I can predict one image but not a set of images with a PyTorch resnet18 model. How can I predict a set of images in a list using PyTorch models?


x is a list of images, each of shape (36, 60, 3). I am trying to run a pretrained PyTorch resnet18 on these images. For testing, x contains 2 images. When I pass only 1 image, I get the prediction with no errors, as follows:

import torch
from torchvision import transforms

im = x[0]  # a single (36, 60, 3) image
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Convert the image to a normalized (3, 36, 60) tensor
img_preprocessed = preprocess(im)
# Add a leading batch dimension: (3, 36, 60) -> (1, 3, 36, 60)
batch_img_tensor = torch.unsqueeze(img_preprocessed, 0)
resnet18.eval()
out = resnet18(batch_img_tensor).flatten()

But it does not work when I set im = x. The preprocessing line fails with this error:

TypeError: pic should be PIL Image or ndarray. Got <class 'list'>

I tried Variable(torch.tensor(x)) as follows:

x=dataset(source_p)
y=Variable(torch.tensor(x))
print(y.shape)
resnet18(y)

I get the following error:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[2, 36, 60, 3] to have 3 channels, but got 36 channels instead
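
This error suggests a layout mismatch: the tensor is channels-last, (N, H, W, C) = (2, 36, 60, 3), while PyTorch convolutions expect channels-first, (N, C, H, W). A hedged sketch of the fix using `permute` follows. The `conv1` layer is only a stand-in with the same weight shape as in the message, not the real network, and the random input is an assumption.

```python
import torch
import torch.nn as nn

# Stand-in for the batch in the error message (assumption): 2 images, channels last
y = torch.randint(0, 256, (2, 36, 60, 3), dtype=torch.uint8)

# Stand-in for the first convolution, with weight shape [64, 3, 7, 7]
# as reported in the error message
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Move channels to dimension 1: (N, H, W, C) -> (N, C, H, W),
# then convert to float in [0, 1]
y_nchw = y.permute(0, 3, 1, 2).float() / 255.0

out = conv1(y_nchw)
print(y_nchw.shape)  # torch.Size([2, 3, 36, 60])
```

Note that this sketch only fixes the layout and dtype; it skips the mean/std normalization that the `preprocess` pipeline applies.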

My question is: how can I predict all the images in the list x at once?

Thanks!


Solution

  • Eventually I created a Dataset class that takes x and applies the transform to each element:

    import torch
    from torch.utils.data import Dataset

    class formDataset(Dataset):

        def __init__(self, imgs, transform=None):
            self.imgs = imgs
            self.transform = transform

        def __len__(self):
            return len(self.imgs)

        def __getitem__(self, idx):
            if torch.is_tensor(idx):
                idx = idx.tolist()
            # Fetch one image; the transform expects a PIL Image or ndarray
            sample = self.imgs[idx]

            if self.transform:
                sample = self.transform(sample)

            return sample
    

    Then I call:

    l_set = formDataset(imgs=x, transform=preprocess)
    l_loader = DataLoader(l_set, batch_size=2)

    for data in l_loader:
        # data is a batch of preprocessed images, shape (batch_size, 3, 36, 60)
        features = resnet18(data)