Search code examples
python pytorch torchvision data-augmentation

Does the order of transforms applied for data augmentation matter in torchvision transforms?


I have the following custom dataset class for an image segmentation task.

class LoadDataset(Dataset):
    """Image-segmentation dataset yielding ``(image, mask)`` float tensor pairs.

    Every regular file found directly inside ``img_dir`` is treated as an
    image and every regular file directly inside ``mask_dir`` as a mask. Both
    listings are sorted and then paired by position, so the two directories
    are expected to sort into matching order.

    NOTE(review): in this version the optional transforms receive two plain
    tensors (not PIL images). The author reports that a random transform such
    as ``RandomHorizontalFlip`` then does not flip image and mask together.
    """

    def __init__(self, img_dir, mask_dir, apply_transforms = None):
        """Record the directories, list the files and build the converters.

        Args:
            img_dir: Directory holding the input images.
            mask_dir: Directory holding the segmentation masks.
            apply_transforms: Optional callable invoked as
                ``apply_transforms(image, mask)``; it must return a pair.
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transforms = apply_transforms
        self.img_paths, self.mask_paths = self.__get_all_paths()
        # Fixed conversion steps used for every sample.
        self.__pil_to_tensor = transforms.PILToTensor()
        self.__float_tensor = transforms.ToDtype(torch.float32, scale=True)
        self.__grayscale = transforms.Grayscale()

    def __get_all_paths(self):
        """Return ``(image_paths, mask_paths)``, each a sorted list of files."""

        def collect(folder):
            # Keep regular files only; sub-directories are skipped.
            found = []
            for entry in os.scandir(folder):
                if os.path.isfile(entry):
                    found.append(os.path.join(folder, entry.name))
            return found

        image_files = collect(self.img_dir)
        mask_files = collect(self.mask_dir)
        # Sorting both listings is what lines image i up with mask i.
        image_files.sort()
        mask_files.sort()
        return image_files, mask_files

    def __len__(self):
        """Return the number of image files found."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load sample ``index`` and return it as ``(image, mask)`` tensors.

        Both files are converted to float32 tensors (with value scaling), the
        mask is additionally passed through ``Grayscale``, and only then are
        the optional transforms applied to the resulting tensors.
        """
        image_file = self.img_paths[index]
        mask_file = self.mask_paths[index]
        opened_image, opened_mask = Image.open(image_file), Image.open(mask_file)

        image = self.__float_tensor(self.__pil_to_tensor(opened_image))
        mask = self.__float_tensor(self.__pil_to_tensor(opened_mask))
        mask = self.__grayscale(mask)

        if not self.transforms:
            return image, mask
        # Transforms run on tensors here, after the PIL -> tensor conversion.
        image, mask = self.transforms(image, mask)
        return image, mask

When I apply the following transformation:

transforms.RandomHorizontalFlip()

only one of the two (either the image or the mask) gets flipped, so they are no longer aligned. But if I change the order of the operations in __getitem__ to the following, it works fine.

def __getitem__(self, index):
    """Load sample ``index`` and return it as ``(image, mask)`` tensors.

    Unlike the earlier version, the optional transforms are applied to the
    PIL images first; conversion to float32 tensors (and ``Grayscale`` for the
    mask) happens afterwards.
    """
    image_file = self.img_paths[index]
    mask_file = self.mask_paths[index]
    image, mask = Image.open(image_file), Image.open(mask_file)

    # Augment while both inputs are still PIL images.
    if self.transforms:
        image, mask = self.transforms(image, mask)

    image, mask = self.__pil_to_tensor(image), self.__pil_to_tensor(mask)
    image, mask = self.__float_tensor(image), self.__float_tensor(mask)
    return image, self.__grayscale(mask)

Does the order of the transformations matter? I'm using torchvision.transforms.v2 for all the transformations.


Solution

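  • Yes, the order of transformations matters. In this case, the conversion to tensors is what makes the difference. When v2.RandomHorizontalFlip is given two plain tensors, they are not handled as an aligned image–mask pair: according to the torchvision v2 documentation, only the first plain tensor is treated as an image and transformed, while any other plain tensors are passed through unchanged. However, when two PIL images are given, the same transform (with the same random parameters) is applied to both images, thus keeping the image and mask aligned.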

    For more consistent handling, you can use TVTensors for the data augmentation. With these, you specify the type of each input (for example, image or mask) before transforming it. For example:

    from torchvision import tv_tensors
    
    # Tag each plain tensor with its role (image vs. mask) before it is
    # handed to the v2 transforms.
    img_tensor = tv_tensors.Image(img_tensor)
    mask_tensor= tv_tensors.Mask(mask_tensor)