
In torchvision, how can I apply a different transformation for each label?


Inside my training folder, I have 3 subfolders, one per label. I want to apply a different transformation for each label. For example, if my transformation adds a small square to the corner of the image, I want the first category to get a red square, the second a blue square and the third a green square. How can I do that?

Here's what I have:

import numpy as np
import torch
from PIL import Image
from torchvision import datasets, transforms

# My transforms.Compose; for now it just converts the PIL image to a tensor
base_transforms = transforms.Compose([transforms.ToTensor()])

class PerturbTransform(object):
    def __call__(self, img):
        img_np = np.array(img)
        img_np[0:20, 0:20, :] = [255, 0, 0]  # Add red square to the corner of the image
        img_pil = Image.fromarray(img_np)
        return base_transforms(img_pil)

# data_dir and batch_size are defined elsewhere in my script
perturb_dataset = datasets.ImageFolder(data_dir + '/training', transform=PerturbTransform())
perturb_loader = torch.utils.data.DataLoader(perturb_dataset, batch_size=batch_size, shuffle=True)

Currently, I am changing every picture to have a red square when it is used for training. But I want the color to be different for each label.

If accessing the image's label is impossible, can I at least access the image's index? I know how many images my training set has, so if I can access the index of the image I'm transforming, I can work with that too.


Solution

  • You can either define a custom dataset, or define a custom image loader that reads the label from the image's path and pass it to ImageFolder as follows:

    import os

    import numpy as np
    import torch
    from PIL import Image
    from torchvision import datasets, transforms

    # Corner-square colour per class folder ('label1' etc. are placeholder folder names)
    LABEL_COLORS = {
        'label1': [255, 0, 0],  # red
        'label2': [0, 0, 255],  # blue
        'label3': [0, 255, 0],  # green
    }

    def custom_loader(path: str) -> Image.Image:
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, "rb") as f:
            img = Image.open(f)
            img = img.convert("RGB")
        img = np.array(img)
        # path is 'path_to_dataset/class/image', so the parent folder name is the label
        label = os.path.basename(os.path.dirname(path))
        if label in LABEL_COLORS:
            img[:20, :20, :] = LABEL_COLORS[label]  # colour the top-left 20x20 corner
        return Image.fromarray(img)

    # Avoid crops here (e.g. CenterCrop) that would cut the corner square off again
    tfms = transforms.Compose([transforms.ToTensor()])
    perturb_dataset = datasets.ImageFolder('/content/dataset', transform=tfms,
                                           loader=custom_loader)
    perturb_loader = torch.utils.data.DataLoader(perturb_dataset, batch_size=2, shuffle=True)
    

    The custom_loader code is based on torchvision's default image loader: https://pytorch.org/vision/main/_modules/torchvision/datasets/folder.html#ImageFolder.

    Note: if your processing involves calculations that produce float values, converting the result back to a PIL Image (8-bit integers) loses information, so this approach won't help in that case.