python python-3.x pytorch dataloader

Error in Transforming custom Dataset in Pytorch


I was following this tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html to make my own custom data loader for the MoNuSeg dataset, but I am stuck at one point. The data loader works fine on its own, but when I add transforms to it, I get errors.

I am facing an issue similar to the one described here: Error Utilizing Pytorch Transforms and Custom Dataset
Following the answer there, I wrapped each of the built-in transforms in a custom transform so that the whole sample is transformed at the same time. Below are my custom transforms:

class AffineTrans(object):
    def __init__(self, degrees, translate):
        self.degrees = degrees
        self.translate = translate

    def __call__(self, sample):
        image, contour, clrmask = sample['image'], sample['contour'], sample['clrmask']
        TF = transforms.RandomAffine(degrees=self.degrees, translate=self.translate)
        image = TF(image)
        contour = TF(contour)
        clrmask = (clrmask)

class Flip(object):
    def __call__(self, sample):
        image, contour, clrmask = sample['image'], sample['contour'], sample['clrmask']
        TF1 = transforms.RandomHorizontalFlip()
        TF2 = transforms.RandomVerticalFlip()
        image = TF1(image)
        contour = TF1(contour)
        clrmask = TF1(clrmask)
        image = TF2(image)
        contour = TF2(contour)
        clrmask = TF2(clrmask)

class ClrJitter(object):
    def __init__(self, brightness, contrast, saturation, hue):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, sample):
        image, contour, clrmask = sample['image'], sample['contour'], sample['clrmask']
        TF = transforms.ColorJitter(self.brightness, self.contrast, self.saturation, self.hue)
        image = TF(image)
        contour = TF(contour)
        clrmask = TF(clrmask)

I composed them in the following way:

composed = transforms.Compose([RandomCrop(256),
                           AffineTrans(15.0,(0.1,0.1)),
                           Flip(),    
                           ClrJitter(10, 10, 10, 0.01),
                           ToTensor()])

And here's the trainLoader code:

class trainLoader(Dataset):
    def __init__(self, transform=None):
        """
        Args:
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        [self.train_data, self.test_data1, self.test_data2] = dirload()
        self.transform = transform

    def __len__(self):
        return len(self.train_data[0])

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.train_data[0][idx]
        contour_name = self.train_data[1][idx]
        color_name = self.train_data[2][idx]
        image = cv2.imread(img_name)
        contour = cv2.imread(contour_name)
        clrmask = cv2.imread(color_name)
        sample = {'image': image, 'contour': contour, 'clrmask': clrmask}

        if self.transform:
            sample = self.transform(sample)

        return sample

To check that the above code works, I am doing the following:

train_dat = trainLoader(composed)

for i in range(len(train_dat)):
    sample = train_dat[i]

    print(i, sample['image'].shape, sample['contour'].shape, sample['clrmask'].shape)
    cv2.imshow('sample', sample['image'])
    cv2.waitKey()
    if i == 3:
        break

But I keep encountering the following error:

runcell(0, 'F:/Moodle/SEM 8/SRE/code/MoNuSeg/main.py')

runcell(1, 'F:/Moodle/SEM 8/SRE/code/MoNuSeg/main.py')
Traceback (most recent call last):

  File "F:\Moodle\SEM 8\SRE\code\MoNuSeg\main.py", line 212, in <module>
    sample = train_dat[i]

  File "F:\Moodle\SEM 8\SRE\code\MoNuSeg\main.py", line 109, in __getitem__
    sample = self.transform(sample)

  File "F:\Moodle\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
    img = t(img)

  File "F:\Moodle\SEM 8\SRE\code\MoNuSeg\main.py", line 156, in __call__
    image = TF(image)

  File "F:\Moodle\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py", line 1018, in __call__
    ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)

  File "F:\Moodle\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py", line 992, in get_params
    max_dx = translate[0] * img_size[0]

TypeError: 'int' object is not subscriptable

This error is quite confusing, as I don't understand what is actually causing it.

Any help would be really appreciated.


Solution

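  • The problem is that you're passing a NumPy array, whereas the torchvision transforms expect a PIL Image. You can fix that by adding transforms.ToPILImage() as the first transform:

    composed = transforms.Compose([
        transforms.ToPILImage(),
        RandomCrop(256),
        AffineTrans(15.0,(0.1,0.1)),
        Flip(),    
        ClrJitter(10, 10, 10, 0.01),
        ToTensor()
    ])
    

    This assumes you have from torchvision import transforms at the beginning of the file.

    Note that transforms.ToPILImage() accepts a single tensor or NumPy array. Since the transforms in the question receive a dict, the conversion has to be applied to each entry of the sample (for example in a small custom transform or inside __getitem__), and any transform that runs after it must accept PIL Images.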
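
Because each transform in the question receives a dict rather than a single image, one way to apply the conversion is a small dict-aware transform placed first in the pipeline. The following is a minimal sketch, not the original author's code: the class name SampleToPIL is hypothetical, it assumes every entry is a 3-channel uint8 array as returned by cv2.imread (BGR channel order), and it uses PIL's Image.fromarray directly rather than transforms.ToPILImage().

```python
import numpy as np
from PIL import Image


class SampleToPIL(object):
    """Hypothetical helper: convert every array in a sample dict to a PIL Image."""

    def __call__(self, sample):
        converted = {}
        for key, arr in sample.items():
            # cv2.imread returns BGR; reverse the last axis to get RGB.
            rgb = np.ascontiguousarray(arr[..., ::-1])
            converted[key] = Image.fromarray(rgb)
        return converted


# Stand-in for three images loaded with cv2.imread (height 20, width 30).
dummy = np.zeros((20, 30, 3), dtype=np.uint8)
dummy[..., 0] = 255  # fill the blue channel (index 0 in BGR order)
sample = {'image': dummy, 'contour': dummy.copy(), 'clrmask': dummy.copy()}

pil_sample = SampleToPIL()(sample)
print(pil_sample['image'].size)  # a (width, height) tuple, not an int
```

With something like this, SampleToPIL() would go first in transforms.Compose. A custom RandomCrop or ToTensor written for NumPy arrays would then need to handle PIL Images, or the conversion could be placed after the crop instead.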

    The root of the problem is that you're using OpenCV to load the images:

    def __getitem__(self, idx):
        # [...]
        image = cv2.imread(img_name)
    

    cv2.imread returns a NumPy array, so you could also solve the problem by replacing these OpenCV loading calls with PIL ones.
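
As a rough sketch of that alternative, the loading step could look like the following. The helper name load_sample and its arguments are hypothetical, and the example writes three small placeholder files only so that it can run on its own.

```python
import os
import tempfile

from PIL import Image


def load_sample(img_name, contour_name, color_name):
    """Hypothetical helper: load one sample as PIL Images instead of NumPy arrays."""
    # convert('RGB') loads the pixel data and gives all entries the same mode.
    image = Image.open(img_name).convert('RGB')
    contour = Image.open(contour_name).convert('RGB')
    clrmask = Image.open(color_name).convert('RGB')
    return {'image': image, 'contour': contour, 'clrmask': clrmask}


# Create three small placeholder files so the sketch is self-contained.
tmp_dir = tempfile.mkdtemp()
paths = []
for name in ('image.png', 'contour.png', 'clrmask.png'):
    path = os.path.join(tmp_dir, name)
    Image.new('RGB', (40, 30)).save(path)
    paths.append(path)

sample = load_sample(*paths)
print(sample['image'].size)  # a (width, height) tuple
```

In the question's __getitem__, the three cv2.imread calls would be the lines to swap. Note that PIL loads images in RGB order, whereas OpenCV uses BGR.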


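    Just so you know, a NumPy array's .size attribute is an int (the total number of elements), whereas a PIL Image's .size is a (width, height) tuple. torchvision indexes into img.size, and that is what causes the error. The following code shows the difference: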

    import numpy as np
    from PIL import Image
    
    # NumPy
    img = np.zeros((30, 30))
    print(img.size)  # output: 900
    
    # PIL
    pil_img = Image.fromarray(img)
    print(pil_img.size)  # output: (30, 30)
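
To connect this to the traceback: the failing line in get_params is max_dx = translate[0] * img_size[0], where img_size is the img.size value passed in by the transform. The sketch below imitates only that one expression (the function max_dx is a hypothetical stand-in, not torchvision code) to show why a NumPy array triggers the TypeError while a PIL Image does not.

```python
import numpy as np
from PIL import Image


def max_dx(translate, img):
    """Hypothetical stand-in for the expression shown in the traceback."""
    img_size = img.size  # int for a NumPy array, (width, height) for PIL
    return translate[0] * img_size[0]


array_img = np.zeros((30, 40, 3), dtype=np.uint8)  # shape cv2.imread would give
pil_img = Image.fromarray(array_img)

print(max_dx((0.1, 0.1), pil_img))  # works: indexes the (width, height) tuple

try:
    max_dx((0.1, 0.1), array_img)  # img.size is a plain int here
except TypeError as err:
    error_message = str(err)
    print(error_message)
```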