Tags: python, pytorch, artificial-intelligence, resnet, torchvision

pytorch dataloader - RuntimeError: stack expects each tensor to be equal size


I'm trying to implement ResNet from scratch.

This error occurred during training, after wrapping the dataset in a DataLoader:

RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [1, 224, 224] at entry 25

I usually know why this message is shown: it means the samples in the dataset don't all have the same shape.

But I apply transforms before the data reaches the DataLoader, so I didn't expect an error like this.

Sometimes the message shows [4, 224, 224] instead of [1, 224, 224].

Of course, I checked the dataset, and all of the images are color images.
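
For what it's worth, I can reproduce the same error with a plain torch.stack call on tensors whose channel counts differ, which is what the DataLoader's default collate function does when it builds a batch (a minimal reproduction, not my training code):

import torch

batch = [torch.zeros(3, 224, 224), torch.zeros(1, 224, 224)]
torch.stack(batch)  # RuntimeError: stack expects each tensor to be equal size ...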

Here's the code

## Dataset class declaration; the transforms are applied in __getitem__()

import os
import glob

import natsort
import torch
from PIL import Image


class cnd_data(torch.utils.data.Dataset):
    def __init__(self, file_path, train=True, transforms=None):

        self.train = train
        self.transforms = transforms

        # Windows paths: raw strings keep the backslashes literal
        self.cat_img_path = os.path.join(file_path, r'data\kagglecatsanddogs\PetImages\Cat')
        self.dog_img_path = os.path.join(file_path, r'data\kagglecatsanddogs\PetImages\Dog')

        self.cat_list = natsort.natsorted(glob.glob(self.cat_img_path + '/*.jpg'))
        self.dog_list = natsort.natsorted(glob.glob(self.dog_img_path + '/*.jpg'))

        if self.train:
            # first 12000 cats and 12000 dogs for training
            self.imgn_list = self.cat_list[:12000] + self.dog_list[:12000]
            self.img_label = [0] * 12000 + [1] * 12000
        else:
            # remaining images (500 of each) for validation
            self.imgn_list = self.cat_list[12000:] + self.dog_list[12000:]
            self.img_label = [0] * 500 + [1] * 500

    def __len__(self):
        return len(self.img_label)

    def __getitem__(self, idx):
        image_data = Image.open(self.imgn_list[idx])

        print(self.imgn_list[idx])  # debug: show which file is being loaded
        sample = image_data
        if self.transforms:
            sample = self.transforms(image_data)

        return sample, self.img_label[idx]


## And the transforms setup
transforms=transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()])
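
The dataset is then wrapped in a DataLoader roughly like this (the root path and batch size here are placeholders, not my exact values):

train_set = cnd_data(file_path='.', train=True, transforms=transforms)  # '.' is a placeholder root path
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)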

Solution

  • PyTorch expects every data sample to have the same size after preprocessing, so that the samples can be stacked into a single batch tensor. According to your error log, [3, 224, 224] refers to an RGB image, [1, 224, 224] is a grayscale image, and [4, 224, 224] is RGBA (it has an alpha channel for opacity). If you use PIL, make sure to convert the image to RGB before applying any transform:

    Image.open(path).convert('RGB')
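
    For example, the conversion could go right at the start of __getitem__ in your dataset class (a sketch of one way to do it):

    def __getitem__(self, idx):
        # .convert('RGB') forces 3 channels, so grayscale and RGBA files
        # also come out as [3, 224, 224] after Resize + ToTensor
        image_data = Image.open(self.imgn_list[idx]).convert('RGB')

        sample = image_data
        if self.transforms:
            sample = self.transforms(image_data)

        return sample, self.img_label[idx]

    An alternative with the same effect is to keep __getitem__ unchanged and add transforms.Lambda(lambda img: img.convert('RGB')) as the first step of the Compose.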