torch.utils.data.random_split() is not splitting Dataset


I am using ImageFolder to load the data from a directory:

full_dataset = ImageFolder('some_dir', transform=transform)

Printing its length gives 32854. Now I want to split the Dataset returned by ImageFolder into train and test datasets using torch.utils.data.random_split(). I tried passing fractions, [0.8, 0.2], and explicit lengths, such as [len(full_dataset) - 100, 100]:

train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [len(full_dataset) - 100, 100])

But when I print their lengths using len(train_dataset.dataset.imgs) and len(test_dataset.dataset.imgs), both show the same value as full_dataset.

Why is my split not working?


Solution

  • The split is working. train_dataset.dataset and test_dataset.dataset both refer to the original dataset (full_dataset in this case), so the imgs attribute on either one lists all the images in the original dataset, not just those in that split.

    The Subset objects returned by random_split implement __len__ (Subset is a subclass of Dataset), so you can get the length of each split by calling len on it directly:

    len(train_dataset)
    len(test_dataset)
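
To see the difference between a split and its underlying dataset, here is a minimal sketch. It assumes a small TensorDataset of 1000 random samples as a hypothetical stand-in for the ImageFolder dataset, so it runs without an image directory. The seeded generator is also an assumption, added only to make the split reproducible.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for ImageFolder('some_dir', ...): 1000 fake samples.
features = torch.randn(1000, 3)
labels = torch.randint(0, 2, (1000,))
full_dataset = TensorDataset(features, labels)

# Seeded generator (an assumption, for reproducibility only).
generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(
    full_dataset, [len(full_dataset) - 100, 100], generator=generator
)

# len() on each Subset reports the size of that split.
print(len(train_dataset))  # 900
print(len(test_dataset))   # 100

# .dataset on a Subset is the original, unsplit dataset.
print(train_dataset.dataset is full_dataset)  # True
print(len(train_dataset.dataset))             # 1000

# .indices holds the positions in the original dataset that belong to the split.
print(len(test_dataset.indices))  # 100
```

With a real ImageFolder, the image entries for one split could likely be recovered the same way, for example [full_dataset.imgs[i] for i in train_dataset.indices].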