deep-learning, conv-neural-network, pytorch, torch, data-augmentation

How do I augment data after splitting the training dataset into train and validation sets for CIFAR10 using PyTorch?


When classifying CIFAR10 in PyTorch, there are normally 50,000 training samples and 10,000 testing samples. However, if I need a validation set, I can create one by splitting the training set into 40,000 training samples and 10,000 validation samples. I used the following code:

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10

train_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

cifar_train_L = CIFAR10('./data', download=True, train=True, transform=train_transform)
cifar_test = CIFAR10('./data', download=True, train=False, transform=test_transform)

train_size = int(0.8 * len(cifar_train_L))
val_size = len(cifar_train_L) - train_size
cifar_train, cifar_val = torch.utils.data.random_split(cifar_train_L, [train_size, val_size])

train_dataloader = torch.utils.data.DataLoader(cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_dataloader = torch.utils.data.DataLoader(cifar_test, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_dataloader = torch.utils.data.DataLoader(cifar_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

Normally, when augmenting data in PyTorch, the augmentation steps are added to the transforms.Compose pipeline (e.g., transforms.RandomHorizontalFlip()). However, if I add these augmentations to the transform before splitting into training and validation sets, the augmentation will also be applied to the validation set. Is there any way I can fix this problem?

In short, I want to manually split the training dataset into train and validation sets, and apply data augmentation only to the new training set.
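For example, the augmented training transform might look something like this (the specific augmentations and the name train_transform_aug are only illustrative):

train_transform_aug = transforms.Compose([
    transforms.RandomCrop(32, padding=4),    # illustrative augmentation
    transforms.RandomHorizontalFlip(),       # illustrative augmentation
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])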


Solution

  • You can manually override the transform that the validation split uses. Note, however, that torch.utils.data.random_split returns Subset objects that both delegate to the same underlying cifar_train_L, and CIFAR10 applies whatever is stored in that underlying dataset's transform attribute, so assigning cifar_val.transforms = test_transform by itself has no effect. One way to make the override work is to point the validation Subset at a shallow copy of the base dataset and change the transform on that copy:

    import copy

    cifar_train, cifar_val = torch.utils.data.random_split(cifar_train_L, [train_size, val_size])
    cifar_val.dataset = copy.copy(cifar_train_L)    # separate view of the same data,
    cifar_val.dataset.transform = test_transform    # so this does not affect cifar_train
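
  • An alternative that avoids relying on Subset internals is to build two CIFAR10 instances over the same files, one with the augmented transform and one with the plain evaluation transform, and index both with the same split of indices. A minimal sketch, where the augmentations and the BATCH_SIZE value are only illustrative:

    import torch
    from torch.utils.data import DataLoader, Subset
    from torchvision import transforms
    from torchvision.datasets import CIFAR10

    BATCH_SIZE = 64  # illustrative value

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),   # illustrative augmentation
        transforms.ToTensor(),
        normalize,
    ])
    eval_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Two views of the same 50,000 training images, differing only in their transform.
    cifar_train_aug = CIFAR10('./data', download=True, train=True, transform=train_transform)
    cifar_train_plain = CIFAR10('./data', download=True, train=True, transform=eval_transform)

    # One shared 80/20 split of indices, applied to both views.
    indices = torch.randperm(len(cifar_train_aug)).tolist()
    split = int(0.8 * len(indices))
    cifar_train = Subset(cifar_train_aug, indices[:split])      # augmented
    cifar_val = Subset(cifar_train_plain, indices[split:])      # not augmented

    train_dataloader = DataLoader(cifar_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_dataloader = DataLoader(cifar_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)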