I want to use one of the image augmentation techniques (for example rotation or horizontal flip) and apply it to some images of the CIFAR-10 dataset and plot them in PyTorch.
I know that we can use the following code to augment images:
from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10
data_transforms = transforms.Compose([
# add augmentations
transforms.RandomHorizontalFlip(p=0.5),
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1]
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
I then pass these transforms when loading the CIFAR-10 dataset:
train_set = CIFAR10(
root='./data/',
train=True,
download=True,
transform=data_transforms)
As far as I know, when this code is used, all CIFAR10 datasets are transformed.
Question
How can I apply a transform or augmentation technique to only some images in the dataset and plot them? For example, 10 images alongside their augmented versions.
"when this code is used, all CIFAR10 datasets are transformed"
Actually, the transform pipeline is only called when images in the dataset are fetched via the __getitem__ function, either directly by the user or through a data loader. So at this point, train_set doesn't contain augmented images; they are transformed on the fly, each time an item is accessed.
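To see that behaviour in isolation, here is a minimal pure-Python sketch of the mechanism. It is not torchvision's actual implementation: TinyDataset and CountingFlip are hypothetical names made up for this illustration, and the "images" are just short lists of numbers.

```python
import random


class TinyDataset:
    """Toy stand-in for a torchvision dataset (hypothetical, for illustration)."""

    def __init__(self, items, transform=None):
        self.items = items
        self.transform = transform  # stored here, not applied yet

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        item = self.items[index]
        if self.transform is not None:
            item = self.transform(item)  # applied only when an item is fetched
        return item


class CountingFlip:
    """Reverses a row of 'pixels' with probability p and counts its calls."""

    def __init__(self, p=0.5):
        self.p = p
        self.calls = 0

    def __call__(self, row):
        self.calls += 1
        return row[::-1] if random.random() < self.p else row


flip = CountingFlip(p=0.5)
data = TinyDataset([[1, 2, 3], [4, 5, 6]], transform=flip)
calls_after_construction = flip.calls  # nothing has been transformed so far

first = data[0]   # the transform runs now
second = data[0]  # and runs again; the result may differ from `first`
```

Constructing the dataset leaves the call counter at zero, and every indexing operation triggers the transform again. The stored items are never modified, which is why the same index can yield a flipped image on one access and an unflipped one on the next.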
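To plot the originals next to their augmented versions, you will need to construct another dataset without augmentations. It still needs ToTensor, because without any transform the dataset returns PIL images, which torch.stack cannot handle:
>>> non_augmented = CIFAR10(
... root='./data/',
... train=True,
... download=True,
... transform=transforms.ToTensor())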
>>> train_set = CIFAR10(
... root='./data/',
... train=True,
... download=True,
... transform=data_transforms)
Stack some images together:
>>> import torch
>>> import torchvision
>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
...                     *[train_set[i][0] for i in range(10)]))
>>> imgs.shape
torch.Size([20, 3, 32, 32])
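Then torchvision.utils.make_grid can be useful to create the desired layout. With nrow=10, the first row holds the 10 original images and the second row their counterparts from train_set: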
>>> grid = torchvision.utils.make_grid(imgs, nrow=10)
There you have it!
>>> transforms.ToPILImage()(grid)