Tags: python, pytorch, dataset, pytorch-dataloader, data-augmentation

Can't apply same transform to image and mask for data-augmentation


I'm trying to train a U-Net model built with PyTorch. For that, I built the dataset and applied data-augmentation transformations to both the image and the mask. The thing is that I want to apply the same transformation to both, meaning that if I rotate the image by some number of degrees, I want the mask rotated by the same amount, and therein lies my problem: the image and the mask aren't rotated by the same amount.

I leave the code below:

Dataset

import torch
from torch.utils.data import Dataset
import os

class INBreastDataset2012(Dataset):
    def __init__(self, dict_dir, transform=None):
        self.dict_dir = dict_dir
        self.data = os.listdir(self.dict_dir)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        dict_path = os.path.join(self.dict_dir, self.data[index])
        patient_dict = torch.load(dict_path)
        image = patient_dict['image'].unsqueeze(0)
        mass_mask = patient_dict['mass_mask'].unsqueeze(0)
        mass_mask[mass_mask > 1.0] = 1.0


        if self.transform is not None:
            # The augmentations are applied to the image and the mask
            # in two separate calls
            image = self.transform(image)
            mass_mask = self.transform(mass_mask)

        return image, mass_mask


"Trainging"(isn't really training at this point, just visualization of the information brought by the dataloader)

from dataset import INBreastDataset2012
from torchvision.transforms import v2 as T
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

train_dir = r'directory\of\training images and masks'
test_dir = r'directory\of\testing images and masks'

train_transform = T.Compose(
        [
            T.RandomRotation(degrees=35, expand=True, fill=255.0),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ]
    )

train_data = INBreastDataset2012(train_dir, transform=train_transform)
test_data = INBreastDataset2012(test_dir)

train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

plt.figure(figsize=(12,12))
for i, (imagen,mascara) in enumerate(train_dataloader):
    ax = plt.subplot(2,4,i+1)
    ax.title.set_text(f'imagen {i+1}')
    plt.imshow(imagen.squeeze(), cmap='gray')
    ax = plt.subplot(2,4,i+3)
    ax.title.set_text(f'mascara de imagen {i+1}')
    plt.imshow(mascara.squeeze(), cmap='gray')
    if i == 1:
        break

Result: [screenshot of the transformed images and masks]

I'll also add that I've tried albumentations and torchvision.transforms v1. The PyTorch examples and YouTube videos I've seen seem to do the same thing I'm doing.

If someone could help me see what I'm doing wrong, or has a solution that ensures the transformations are the same, it would be greatly appreciated.

If any extra information is needed, please ask. This is my first post, so I may have missed something. Thank you in advance.


Solution

  • The transforms differ because each call to self.transform draws fresh random parameters, so transforming the image and the mask in two separate calls rotates and flips them independently. You could instead concatenate the image and mask along the channel dimension, run the transform once, and then split the result back into two tensors. The snippet below assumes the image and mask are shaped channels x height x width.

    ...

    if self.transform is not None:
        # Concatenate along the channel dimension.
        # Assuming dim=0 is the channel dimension (not the batch dim)
        n_channels = image.shape[0]
        image_and_mask = torch.cat([image, mass_mask], dim=0)

        # Transform both tensors in a single call so they share
        # the same random parameters
        transformed = self.transform(image_and_mask)

        # Slice the tensors back out
        image = transformed[:n_channels, ...]
        mass_mask = transformed[n_channels:, ...]

    ...
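
  • One caveat with concatenation: the rotation's fill value (255.0 here) gets painted into the mask's new border pixels too, so the mask is no longer binary afterwards. Since the training script already imports the v2 transforms, another option (assuming a torchvision recent enough to ship tv_tensors) is to let the transforms handle both inputs natively: wrap the mask in tv_tensors.Mask, pass image and mask to the transform in a single call, and the same random parameters are applied to both. A minimal sketch under that assumption:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms import v2 as T

    transform = T.Compose(
        [
            # A per-type fill keeps the image border white while
            # the mask border stays at 0
            T.RandomRotation(
                degrees=35,
                expand=True,
                fill={tv_tensors.Image: 255.0, tv_tensors.Mask: 0},
            ),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ]
    )

    # Inside __getitem__: one call transforms both tensors with
    # identical random parameters
    image = tv_tensors.Image(image)
    mass_mask = tv_tensors.Mask(mass_mask)
    image, mass_mask = transform(image, mass_mask)

  • If you're stuck on the v1 API, you can also draw the random parameters yourself once and apply them to both tensors through the functional interface. A sketch of that approach (paired_transform is a hypothetical helper; call it from __getitem__ in place of self.transform):

    import torch
    import torchvision.transforms.functional as F
    from torchvision.transforms import RandomRotation

    def paired_transform(image, mass_mask):
        # Sample the rotation angle once, then rotate both tensors by it
        angle = RandomRotation.get_params([-35.0, 35.0])
        image = F.rotate(image, angle, expand=True, fill=255.0)
        mass_mask = F.rotate(mass_mask, angle, expand=True, fill=0.0)
        # Flip both tensors or neither
        if torch.rand(1) < 0.5:
            image, mass_mask = F.hflip(image), F.hflip(mass_mask)
        if torch.rand(1) < 0.5:
            image, mass_mask = F.vflip(image), F.vflip(mass_mask)
        return image, mass_mask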