Inside my training folder, I have 3 subfolders, each with its own label. I want to apply a different transformation for each label. Say my transformation draws a small colored square in the corner of the image: I want the first category to get a red square, the second a blue square, and the third a green square. How can I do that?
Here's what I have:
class PerturbTransform(object):
    def __call__(self, img):
        img_np = np.array(img)
        img_np[0:20, 0:20, :] = [255, 0, 0]  # Add red square to the corner of the image
        img_pil = Image.fromarray(img_np)
        img_pil = transforms(img_pil)  # this line calls my transforms.Compose, it just converts it to tensor
        return img_pil

perturb_dataset = datasets.ImageFolder(data_dir + '/training', transform=PerturbTransform())
perturb_loader = torch.utils.data.DataLoader(perturb_dataset, batch_size=batch_size, shuffle=True)
Currently, I am changing every picture to have a red square when it is used for training. But I want the color to be different for each label.
If accessing the image's label is impossible, can I at least access the image index? I know how many images my training set has, so if I can access the index of the image I'm transforming, I can work with that too.
You can either define a custom dataset or define a custom image loader and pass it to ImageFolder, as follows:
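import os

import numpy as np
import torch
from PIL import Image
from torchvision import datasets, transforms

def custom_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        img = img.convert("RGB")
    img = np.array(img)
    # path has the form 'path_to_dataset/class/image', so the parent folder name is the label
    label = os.path.basename(os.path.dirname(path))
    # 'label1', 'label2', 'label3' are placeholders for your class folder names
    if label == 'label1':
        img[:20, :20, :] = [255, 0, 0]  # Add red square to the corner of the image
    elif label == 'label2':
        img[:20, :20, :] = [0, 0, 255]  # blue square
    elif label == 'label3':
        img[:20, :20, :] = [0, 255, 0]  # green square
    img_pil = Image.fromarray(img)
    return img_pil

# Example transforms; substitute your own. Note that a center crop can remove the corner square.
tfms = transforms.Compose([transforms.CenterCrop(10),
                           transforms.ToTensor()])
perturb_dataset = datasets.ImageFolder('/content/dataset', transform=tfms,
                                       loader=custom_loader)
perturb_loader = torch.utils.data.DataLoader(perturb_dataset, batch_size=2, shuffle=True)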
custom_loader is based on torchvision's default loader (https://pytorch.org/vision/main/_modules/torchvision/datasets/folder.html#ImageFolder).
Note: If your perturbation involves calculations that produce float values, and converting the result back to an Image object would cause considerable information loss, this approach won't help.
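For that float-valued case, one possible workaround is to perturb the tensor after ToTensor instead of the PIL image, so nothing is converted back to an Image. The sketch below is only an illustration: `TensorPerturbDataset` and `SQUARE_COLORS_FLOAT` are hypothetical names, and it assumes the wrapped dataset yields (C x H x W float tensor, class index) pairs with three channels and three classes.

```python
import torch
from torch.utils.data import Dataset

# Hypothetical per-class colors as floats in [0, 1], the range ToTensor produces.
SQUARE_COLORS_FLOAT = {
    0: torch.tensor([1.0, 0.0, 0.0]),  # red
    1: torch.tensor([0.0, 0.0, 1.0]),  # blue
    2: torch.tensor([0.0, 1.0, 0.0]),  # green
}


class TensorPerturbDataset(Dataset):
    """Wraps a dataset of (image tensor, class index) pairs and perturbs the tensor."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        img, target = self.base[index]
        img = img.clone()  # do not modify a tensor the base dataset may reuse
        # Broadcast the (3, 1, 1) color over the top-left 20 x 20 region
        img[:, :20, :20] = SQUARE_COLORS_FLOAT[target].view(3, 1, 1)
        return img, target
```

A possible usage is `TensorPerturbDataset(datasets.ImageFolder(root, transform=transforms.ToTensor()))`, passed to the DataLoader as before. Any float computation can replace the color assignment.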