
pic should be Tensor or ndarray. Got <class 'NoneType'>


I am a beginner in PyTorch. I want to train a network on the NYU dataset, but I am getting an error.

[Screenshot of the error message]

The error occurs when I use a DataLoader to load my local dataset. I am printing the data to check that the code is correct:

test = Mydataset(data_root, transforms, 'image_train')
test2 = DataLoader(test, batch_size=4, num_workers=0, shuffle=False)
for idx, data in enumerate(test2):
    print(idx)

Here's the rest of the code with the Mydataset definition:

from __future__ import division, absolute_import, print_function
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
data_root = 'D:/AuxiliaryDocuments/NYU/'
transforms = transforms.Compose([transforms.ToPILImage(),
                                 transforms.Resize(224, 101),
                                 transforms.ToTensor()])

filename_txt = {'image_train': 'image_train.txt', 'image_test': 'image_test.txt',
                'depth_train': 'depth_train.txt', 'depth_test': 'depth_test.txt'}


class Mydataset(Dataset):
    def __init__(self, data_root, transformation, data_type):
        self.transform = transformation
        self.image_path_txt = filename_txt[data_type]
        self.sample_list = list()
        f = open(data_root + '/' + data_type + '/' + self.image_path_txt)
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            line = line.replace(';', '')
            self.sample_list.append(line)
        f.close()

    def __getitem__(self, index):
        item = self.sample_list[index]
        img = Image.open(item)
        if self.transform is not None:
            img = self.transform(img)
        idx = index
        return idx, img

    def __len__(self):
        return len(self.sample_list)

Solution

  • The error in the title is different from the one in the image (which you should have posted as text, by the way). Assuming the one from the image is correct, your problem is the following:

    Your transforms pipeline begins with transforms.ToPILImage(), but the image returned by Image.open() in __getitem__ is already a PIL image. ToPILImage() only accepts a Tensor or an ndarray, so it raises the error. If you remove that transformation, the error should go away.

    # [...]
    transforms = transforms.Compose([
        transforms.ToPILImage(),  # <<< remove this
        transforms.Resize((224, 101)),  # pass the size as one (h, w) tuple; a 2nd positional arg is `interpolation`
        transforms.ToTensor()
    ])
    
    # [...]
    
    class Mydataset(Dataset):
        # [...]
        def __getitem__(self, index):
            item = self.sample_list[index]
            img = Image.open(item)  # <<< this image is already a PIL image
            if self.transform is not None:
                img = self.transform(img)
            idx = index
            return idx, img
        # [...]