I have the following Custom dataset class for an image segmentation task.
class LoadDataset(Dataset):
    """Paired image/mask dataset for segmentation.

    Images and masks are matched by the sorted order of their filenames,
    so the two directories must contain correspondingly named files.
    TODO(review): confirm the two directories always hold equal-length,
    name-aligned file sets — nothing here validates that.
    """

    def __init__(self, img_dir, mask_dir, apply_transforms=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # Joint transform applied to (image, mask) together so random
        # augmentations use the same parameters for both.
        self.transforms = apply_transforms
        self.img_paths, self.mask_paths = self.__get_all_paths()
        self.__pil_to_tensor = transforms.PILToTensor()
        self.__float_tensor = transforms.ToDtype(torch.float32, scale=True)
        self.__grayscale = transforms.Grayscale()

    def __get_all_paths(self):
        """Return sorted, parallel lists of image and mask file paths."""
        # entry.is_file() uses the stat info cached by scandir, avoiding the
        # extra system call that os.path.isfile(entry) would make.
        img_paths = sorted(
            os.path.join(self.img_dir, entry.name)
            for entry in os.scandir(self.img_dir)
            if entry.is_file()
        )
        mask_paths = sorted(
            os.path.join(self.mask_dir, entry.name)
            for entry in os.scandir(self.mask_dir)
            if entry.is_file()
        )
        return img_paths, mask_paths

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        """Return the (image, mask) pair at *index* as float32 tensors."""
        img_path, mask_path = self.img_paths[index], self.mask_paths[index]
        # Image.open is lazy and keeps the file handle open until the pixel
        # data is read; the context managers close the files promptly instead
        # of leaking descriptors across many __getitem__ calls.
        with Image.open(img_path) as img_PIL:
            img_tensor = self.__pil_to_tensor(img_PIL)
        with Image.open(mask_path) as mask_PIL:
            mask_tensor = self.__pil_to_tensor(mask_PIL)
        img_tensor = self.__float_tensor(img_tensor)
        mask_tensor = self.__float_tensor(mask_tensor)
        mask_tensor = self.__grayscale(mask_tensor)
        if self.transforms:
            img_tensor, mask_tensor = self.transforms(img_tensor, mask_tensor)
        return img_tensor, mask_tensor
When I am applying the following transformation:
transforms.RandomHorizontalFlip()
either the image or the mask is being flipped. But if I change the order of the transformations in __getitem__
to the following, then it works fine.
def __getitem__(self, index):
    """Return the (image, mask) pair at *index* as float32 tensors.

    The joint transform runs on the PIL images *before* tensor conversion,
    so v2 random transforms recognize both inputs as images and apply the
    same random parameters to each, keeping image and mask aligned.
    """
    img_path, mask_path = self.img_paths[index], self.mask_paths[index]
    # Image.open is lazy; the context managers close the underlying files
    # once the pixel data has been converted, avoiding descriptor leaks.
    with Image.open(img_path) as img_PIL, Image.open(mask_path) as mask_PIL:
        if self.transforms:
            img_PIL, mask_PIL = self.transforms(img_PIL, mask_PIL)
        img_tensor = self.__pil_to_tensor(img_PIL)
        mask_tensor = self.__pil_to_tensor(mask_PIL)
    img_tensor = self.__float_tensor(img_tensor)
    mask_tensor = self.__float_tensor(mask_tensor)
    mask_tensor = self.__grayscale(mask_tensor)
    return img_tensor, mask_tensor
Does the order of the transformations matter? I'm using torchvision.transforms.v2
for all of the transformations.
Yes, the order of transformations matters, and the conversion to tensors is what makes the difference here. When v2.RandomHorizontalFlip
is given two plain tensors, v2's heuristic treats only the first pure tensor as an image and passes the second one through untransformed, so the image can be flipped while the mask is not. When two PIL images are given, both are recognized as images and the same random flip is applied to both, keeping the image and mask aligned.
For more consistent handling, you can use TVTensors
for the data augmentation. These let you declare the type of each input (image vs. mask) before transforming, so v2 transforms handle both correctly regardless of ordering. For example:
from torchvision import tv_tensors
img_tensor = tv_tensors.Image(img_tensor)
mask_tensor = tv_tensors.Mask(mask_tensor)