Search code examples
python, python-3.x, machine-learning, pytorch

Related to SubsetRandomSampler


I am using SubsetRandomSampler to split a classification dataset into test and validation sets. Can the split be done separately for each class?

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

# Convert each image to a tensor, then normalize its three channels with the
# given per-channel means and standard deviations.
train_transforms = transforms.Compose([transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])
# `datasets` must be imported from torchvision together with `transforms`;
# without that import this line raises NameError.
dataset = datasets.ImageFolder('/data/images/train', transform=train_transforms)

validation_split = .2   # fraction of the samples held out for validation
shuffle_dataset = True  # shuffle the indices before cutting them into two parts
random_seed = 42
batch_size = 20

dataset_size = len(dataset)  # 4996
indices = list(range(dataset_size))
# Number of validation samples, rounded down.
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset:
    # Seed NumPy's global RNG so the same split is produced on every run.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# The first `split` indices form the validation set, the rest the training set.
# NOTE: this split ignores class labels, so the per-class proportions in the
# two subsets are not controlled.
train_indices, val_indices = indices[split:], indices[:split]

# Each sampler draws only from its own index list, in random order.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

# Both loaders wrap the same dataset; the samplers keep the subsets disjoint.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

Solution

  • Did you mean train and validation, rather than test and validation?

    If so, SubsetRandomSampler randomly selects samples from the indices it is given. Therefore, you can simply split the indices of each class at random before putting them into train_indices and val_indices.

    For example:

    # Split the dataset per class: every class contributes `validation_split`
    # of its own samples (rounded down) to the validation set.

    # One index list per class. A comprehension is required here:
    # `[[]] * len(dataset.classes)` would repeat a single shared list object.
    indices_by_class = [[] for _ in dataset.classes]
    # Read the labels from `dataset.targets` rather than iterating `dataset`,
    # which would load and transform every image just to obtain its label.
    for idx, class_idx in enumerate(dataset.targets):
        indices_by_class[class_idx].append(idx)

    train_indices, val_indices = [], []
    for class_indices in indices_by_class:
        # Number of validation samples for this class, rounded down
        # (this can be 0 for a very small class).
        split = int(np.floor(validation_split * len(class_indices)))
        # Uses NumPy's global RNG; call np.random.seed(...) beforehand if the
        # split has to be reproducible.
        np.random.shuffle(class_indices)
        train_indices.extend(class_indices[split:])
        val_indices.extend(class_indices[:split])

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)