python, neural-network, pytorch

Using random_split() in PyTorch to split the training set into training and validation sets


    import torchvision

    train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, download=True)
    test_dataset  = torchvision.datasets.FashionMNIST(data_dir, train=False, download=True)

Using the two lines above, I load the Fashion-MNIST dataset, then convert the images to tensors and wrap the datasets in DataLoaders with the code below:

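    from torch.utils.data import DataLoader
    from torchvision import transforms

    # ToTensor converts each PIL image into a float tensor scaled to [0, 1]
    tr = transforms.Compose([transforms.ToTensor()])
    train_dataset.transform = tr
    test_dataset.transform = tr
    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)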

Then I iterate over the data with a for loop like the one below to train the model in PyTorch:

    for images, labels in train_dataloader:
        ...  # forward pass, loss, backward pass, optimizer step

But when I split the training data into two parts with random_split() as shown below, the same for loop raises an error:

    from torch.utils.data import random_split

    train_dataset, val_dataset = random_split(train_dataset, (50000, 10000))

    train_dataset.transform = tr
    test_dataset.transform = tr
    val_dataset.transform = tr

    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    validation_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

The error is:

default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

How can I solve this issue?


Solution

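  • Pass the transform directly to the FashionMNIST constructor (or at least set it on the dataset before calling random_split()). random_split() returns Subset objects that wrap the original dataset, and setting .transform on a Subset has no effect, so the images still reach the DataLoader as PIL images.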

    train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, download=True, transform=tr)
    test_dataset  = torchvision.datasets.FashionMNIST(data_dir, train=False, download=True, transform=tr)