How to use torchvision.transforms for data augmentation in a segmentation task in PyTorch?


I am a little confused about how data augmentation is performed in PyTorch.

For segmentation tasks, the image and its mask must receive exactly the same augmentation, but some transforms, such as random rotation, draw their parameters at random.

Keras guarantees that the data and the mask undergo the same operation when both generators are given the same random seed, as shown in the following code:

    data_gen_args = dict(featurewise_center=True,
                         featurewise_std_normalization=True,
                         rotation_range=25,
                         horizontal_flip=True,
                         vertical_flip=True)


    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    seed = 1
    image_generator = image_datagen.flow(train_data, seed=seed, batch_size=1)
    mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=1)

    train_generator = zip(image_generator, mask_generator)

I didn't find a similar description in the official PyTorch documentation, so I don't know how to ensure that the data and the mask are transformed in sync.

PyTorch does provide such transforms (torchvision.transforms), but I want to apply them inside the __getitem__ of a custom dataset.

For example:

def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))

    temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')

    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = np.uint8(img)
        mask = np.uint8(mask)
        img = self.transforms(img)
        mask = self.transforms(mask)

    return img, mask

In this case, img and mask are transformed separately. Because operations such as random rotation draw their parameters independently on each call, the correspondence between image and mask can be lost: for example, the image may be rotated while the mask is not, or is rotated by a different angle.

EDIT 1

I used the method in augmentations.py, but I got an error:

Traceback (most recent call last):
  File "test_transform.py", line 87, in <module>
    for batch_idx, image, mask in enumerate(train_loader):
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/data.py", line 164, in __getitem__
    img, mask = self.transforms(img, mask)
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/augmentations.py", line 17, in __call__
    img, mask = a(img, mask)
TypeError: __call__() takes 2 positional arguments but 3 were given

This is my code for __getitem__()

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        transforms.ToTensor()
    ]),
}

train_set = DatasetUnetForTestTransform(fold=args.fold, random_index=args.random_index,transforms=data_transforms['train'])

# __getitem__ in class DatasetUnetForTestTransform
def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))
    temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_img, temp_label = crop_data_label_from_0(temp_img, temp_label)
    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = T.ToPILImage()(np.uint8(img))
        mask = T.ToPILImage()(np.uint8(mask))
        img, mask = self.transforms(img, mask)

    img = T.ToTensor()(img).clone()
    mask = T.ToTensor()(mask).clone()
    return img, mask

EDIT 2

I found that after ToTensor, the Dice score between identical labels becomes 255 instead of 1. How can I fix this?

# Dice computation
def DSC_computation(label, pred):
    pred_sum = pred.sum()
    label_sum = label.sum()
    inter_sum = np.logical_and(pred, label).sum()
    return 2 * float(inter_sum) / (pred_sum + label_sum)

Feel free to ask if more code is needed to explain the problem.


Solution

  • torchvision also provides similar functions [document].

    Here is a simple example:

    import torchvision
    from torchvision import transforms

    # MNIST images have a single channel, so Normalize takes one mean and one std.
    trans = transforms.Compose([transforms.CenterCrop((178, 178)),
                                transforms.Resize(128),
                                transforms.RandomRotation(20),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
    # The keyword argument is `transform` (singular); data_root is the dataset directory.
    dset = torchvision.datasets.MNIST(data_root, transform=trans)
    

    EDIT

    A brief example of a custom CelebA dataset. Note that to apply the transformations, you need to call the transform list in __getitem__.

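    import os

    import numpy as np
    import pandas as pd
    from PIL import Image
    from torch.utils.data import Dataset

    class CelebADataset(Dataset):
        def __init__(self, root, transforms=None, num=None):
            super(CelebADataset, self).__init__()

            self.img_root = os.path.join(root, 'img_align_celeba')
            self.attr_root = os.path.join(root, 'Anno/list_attr_celeba.txt')
            self.transforms = transforms

            # Raw string so that \s is passed to the regex engine unchanged.
            df = pd.read_csv(self.attr_root, sep=r'\s+', header=1, index_col=0)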
            #print(df.columns.tolist())
            if num is None:
                self.labels = df.values
                self.img_name = df.index.values
            else:
                self.labels = df.values[:num]
                self.img_name = df.index.values[:num]
    
        def __getitem__(self, index):
            img = Image.open(os.path.join(self.img_root, self.img_name[index]))
            # only use blond_hair, eyeglass, male, smile
            indices = [9, 15, 20, 31]
            label = np.take(self.labels[index], indices)
            label[label==-1] = 0
    
            if self.transforms is not None:
                img = self.transforms(img)
    
            return np.asarray(img), label
    
        def __len__(self):
            return len(self.labels)
    
    

    EDIT 2

    I probably missed something at first glance. The main point of your problem is how to apply "the same" data preprocessing to both img and labels. To my understanding, there is no built-in PyTorch function for this, so what I did before was to implement the augmentation myself.

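    import random

    from PIL import Image

    class RandomRotate(object):
        def __init__(self, degree):
            self.degree = degree

        def __call__(self, img, mask):
            # Draw one angle and apply it to both inputs so they stay aligned.
            rotate_degree = random.random() * 2 * self.degree - self.degree
            # Bilinear for the image; nearest for the mask so label values are not blended.
            return (img.rotate(rotate_degree, Image.BILINEAR),
                    mask.rotate(rotate_degree, Image.NEAREST))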
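
To chain paired transforms like the RandomRotate above, the container also has to pass both inputs through each step. The following is a minimal sketch, not the augmentations.py from the question: the names Compose and RandomHorizontallyFlip are borrowed from the question's code, but their bodies and the p parameter are assumptions made for illustration.

```python
import random

from PIL import Image


class Compose:
    """Chain paired transforms; each step receives and returns (img, mask)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask


class RandomHorizontallyFlip:
    """Flip image and mask together with probability p (hypothetical parameter)."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, mask):
        # A single random draw decides the flip for both inputs.
        if random.random() < self.p:
            return (img.transpose(Image.FLIP_LEFT_RIGHT),
                    mask.transpose(Image.FLIP_LEFT_RIGHT))
        return img, mask


# Tiny demo: one bright pixel in the top-left corner of both image and mask.
img = Image.new("L", (4, 2), 0)
img.putpixel((0, 0), 255)
mask = img.copy()

always_flip = Compose([RandomHorizontallyFlip(p=1.0)])
out_img, out_mask = always_flip(img, mask)
```

Because this Compose calls every step with two arguments, a single-input transform such as transforms.ToTensor() cannot go in the list; that appears to be the cause of the TypeError in EDIT 1 of the question. Converting to tensors after the paired transforms return avoids it.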
    

    Note that both inputs must be PIL images. See this for more information.
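
Regarding the Dice value of 255 in EDIT 2 of the question: a likely cause, assuming the masks hold 0/1 labels cast to uint8, is that ToTensor divides uint8 input by 255, so each foreground pixel becomes 1/255. The two sums then shrink by a factor of 255 while the logical_and count does not. A minimal sketch that binarizes both arrays before counting (the function name, the threshold parameter, and the empty-mask convention are assumptions):

```python
import numpy as np


def dsc_computation(label, pred, threshold=0):
    """Dice coefficient computed on binarized arrays."""
    # Binarize first so the result does not depend on how the values are scaled.
    label = np.asarray(label) > threshold
    pred = np.asarray(pred) > threshold
    denom = label.sum() + pred.sum()
    if denom == 0:
        # Both masks are empty; treated as a perfect match in this sketch.
        return 1.0
    inter = np.logical_and(label, pred).sum()
    return 2.0 * inter / denom


# A 0/1 uint8 mask after ToTensor-style scaling: foreground pixels equal 1/255.
scaled = np.array([[0.0, 1 / 255], [1 / 255, 0.0]], dtype=np.float32)
print(dsc_computation(scaled, scaled))
```

If the network output is a probability map rather than a hard mask, a threshold such as 0.5 for pred would be the more usual choice.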