Tags: python, machine-learning, pytorch

Creating a train, test split for data nested in multiple folders


I am preparing my data for training an image recognition model. I currently have one folder (the dataset) that contains one subfolder per label, named after that label, with the images for that label inside it.

I want to split this dataset into two main folders with the same subfolders, where the number of images in each follows a chosen train/test split, for instance 90% of the images in the train dataset and 10% in the test dataset.

I am struggling to find the best way to split my data. I have read a suggestion that PyTorch's torch.utils.data.Dataset class might be a way to do it, but I can't get it working in a way that preserves the folder hierarchy.


Solution

  • If you have a folder structure like this:

    folder
    │     
    │
    └───class1
    │   │   file011
    │   │   file012
    │   
    └───class2
        │   file021
        │   file022
    

    You can simply use the torchvision.datasets.ImageFolder class.

    As stated in the PyTorch documentation:

    A generic data loader where the images are arranged in this way:

    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png
    
    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png
    

    Then, once you have created your ImageFolder instance, for example like this:

    import torchvision

    dataset = torchvision.datasets.ImageFolder(YOUR_PATH, ...)
    

    you can split it in this way:

    import torch

    test_size = int(0.1 * len(dataset))  # range() needs an integer, not a float
    test_set = torch.utils.data.Subset(dataset, range(test_size))  # first 10% for test
    train_set = torch.utils.data.Subset(dataset, range(test_size, len(dataset)))  # the rest for train
    

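    If you want to shuffle before splitting, remember that Subset selects samples by index, so you can shuffle the indices and then split them. This matters here because ImageFolder lists its samples class by class, so an unshuffled split would fill the test set from the first classes only. Something like this, using torch.randperm to produce the shuffled indices:

    import torch

    indexes = torch.randperm(len(dataset)).tolist()  # shuffled list of all sample indices
    indexes_train = indexes[:int(len(dataset) * 0.9)]
    indexes_test = indexes[int(len(dataset) * 0.9):]

    Then pass indexes_train and indexes_test to torch.utils.data.Subset in place of the ranges used above.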
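
The Subset approach above splits the dataset in memory and leaves the files where they are. If you really need two folders on disk with the same class subfolders, as the question asks, one option is to copy the files yourself using only the standard library. The following is a minimal sketch, not part of torchvision: the function name split_image_folder, its parameters, and the train/ and test/ output folder names are all assumptions made for illustration. It shuffles and splits each class folder separately, so every class keeps roughly the same train/test ratio.

```python
import random
import shutil
import tempfile
from pathlib import Path


def split_image_folder(src_root, dst_root, test_fraction=0.1, seed=0):
    """Hypothetical helper: copy src_root/<class>/<file> into
    dst_root/train/<class>/ and dst_root/test/<class>/, class by class."""
    src_root, dst_root = Path(src_root), Path(dst_root)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for class_dir in sorted(p for p in src_root.iterdir() if p.is_dir()):
        files = sorted(p for p in class_dir.iterdir() if p.is_file())
        rng.shuffle(files)
        n_test = int(len(files) * test_fraction)  # rounds down
        parts = {"test": files[:n_test], "train": files[n_test:]}
        for split_name, split_files in parts.items():
            target_dir = dst_root / split_name / class_dir.name
            target_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy2(f, target_dir / f.name)  # copy, originals stay put
        counts[class_dir.name] = (len(files) - n_test, n_test)
    return counts  # {class_name: (n_train, n_test)}


# Tiny demo on a throwaway directory filled with empty placeholder files.
demo_root = Path(tempfile.mkdtemp())
for class_name in ("class1", "class2"):
    (demo_root / "dataset" / class_name).mkdir(parents=True)
    for i in range(10):
        (demo_root / "dataset" / class_name / f"file{i:02d}.png").touch()

demo_counts = split_image_folder(demo_root / "dataset", demo_root / "split")
print(demo_counts)
```

Each of the two resulting folders has the layout that torchvision.datasets.ImageFolder expects, so they can be loaded as separate train and test datasets. The sketch copies rather than moves files, which doubles disk usage but leaves the original dataset untouched.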