
Splitting a PyTorch dataset into a train loader and a validation loader: why do both have the same length, even though the dataset was split?


I'm trying to split one of the built-in torchvision datasets (MNIST) into a training set and a validation set as follows:

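import numpy as np
import torch
from torch.utils.data import sampler
from torchvision import datasets, transforms

def get_train_valid_splits(data_dir,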
                           batch_size,
                           random_seed=1,
                           valid_size=0.2,
                           shuffle=True,
                           num_workers=4,
                           pin_memory=False):

    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST

    # define transforms
    valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

    train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

    # load the dataset
    train_dataset = datasets.MNIST(root=data_dir, train=True,
                download=True, transform=train_transform)

    valid_dataset = datasets.MNIST(root=data_dir, train=True,
                download=True, transform=valid_transform)

    dataset_size = len(train_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(valid_size * dataset_size))

    # shuffle the indices before splitting so the validation set is random
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = sampler.SubsetRandomSampler(train_idx)
    valid_sampler = sampler.SubsetRandomSampler(valid_idx)

    print(len(train_sampler))
    print(len(valid_sampler))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                    batch_size=batch_size, sampler=train_sampler,
                    num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                    batch_size=batch_size, sampler=valid_sampler,
                    num_workers=num_workers, pin_memory=pin_memory)

    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))

    return (train_loader, valid_loader)

After calling the function, I can see that the samplers hold the right number of indices, 48000 and 12000:

print(len(train_sampler))
print(len(valid_sampler))

But when I look at the length of the dataset associated with train_loader and valid_loader:

print(len(train_loader.dataset))
print(len(valid_loader.dataset))

I get the same length for both: 60000! Any idea what is going on here? Why do both report the same length, even though I clearly split the dataset by indices?


Solution

  • This happens because DataLoader doesn't modify the dataset you pass it; it only applies the sampler, batch size, etc. when you iterate over it. len(loader.dataset) returns the length of that original, unmodified dataset (60000 for both of your loaders, since both wrap the full MNIST training set). What you want instead is len(loader), which is the number of batches the loader yields after the sampler and batch size are applied, or len(loader.sampler), which is the number of samples.

    import torch
    import numpy as np

    dataset = np.random.rand(100, 200)
    sampler = torch.utils.data.SubsetRandomSampler(list(range(70)))

    loader = torch.utils.data.DataLoader(dataset, sampler=sampler)
    print(len(loader))           # 70 (batch_size defaults to 1)
    print(len(loader.dataset))   # 100

    Note: len(loader) counts batches, so its result depends on the batch size:

    # with batch size
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=2)
    print(len(loader))           # 35 (70 samples / 2 per batch)
    print(len(loader.dataset))   # 100
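To make the distinction concrete, here is a small sketch reusing the toy setup above (100 rows, a sampler over 70 indices). The batch size of 32 is an arbitrary choice for illustration, picked so that the last batch is partial. It compares len(loader.dataset), len(loader.sampler) and len(loader), and shows how drop_last changes the batch count. The last part swaps in torch.utils.data.Subset, which is a different technique from the question's sampler approach. Wrapping the split in a Subset means len(loader.dataset) itself reports the split size.

```python
import math

import numpy as np
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler

# Same toy setup as above: 100 rows, a sampler over the first 70 indices.
dataset = np.random.rand(100, 200)
sampler = SubsetRandomSampler(list(range(70)))
batch_size = 32  # arbitrary, chosen so the last batch is partial

loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

n_dataset = len(loader.dataset)  # size of the wrapped dataset (100)
n_samples = len(loader.sampler)  # indices the sampler yields per epoch (70)
n_batches = len(loader)          # batches per epoch (3)

# Without drop_last the final batch may be partial: batches = ceil(samples / batch_size).
assert n_batches == math.ceil(n_samples / batch_size)

# With drop_last=True the partial batch is discarded: batches = floor(samples / batch_size).
loader_drop = DataLoader(dataset, sampler=sampler, batch_size=batch_size, drop_last=True)
assert len(loader_drop) == n_samples // batch_size

# Iterating the loader visits the sampled rows only, not the whole dataset.
seen = sum(batch.shape[0] for batch in loader)
assert seen == n_samples

# Alternative: wrap the selected indices in a Subset, so the loader's
# dataset *is* the split (no sampler needed; shuffle=True shuffles within it).
subset = Subset(dataset, list(range(70)))
subset_loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
assert len(subset_loader.dataset) == 70

print(n_dataset, n_samples, n_batches, len(loader_drop), seen)  # 100 70 3 2 70
```

Applied to the function in the question, len(train_loader.sampler) and len(valid_loader.sampler) give 48000 and 12000, while len(train_loader) and len(valid_loader) give the number of batches for the chosen batch_size.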