pytorch, image-preprocessing, pytorch-dataloader

Applying a simple transformation to get a binary image using PyTorch


I'd like to binarize an image before passing it to the dataloader. I have created a dataset class that works well, but in the __getitem__() method I'd like to threshold the image:

    def __getitem__(self, idx):
        # Open image, apply transforms and return with label
        img_path = os.path.join(self.dir, self.filelist[idx])  # assuming filelist is indexed by idx
        image = Image.open(img_path)
        label = self.x_data.iloc[idx]["label"]

        # Applying transformation to the image
        if self.transforms is not None:
            image = self.transforms(image)

        # applying threshold here:
        my_threshold = 240
        image = image.point(lambda p: p < my_threshold and 255)
        image = torch.tensor(image)

        return image, label

And then I tried to invoke the dataset:

    data_transformer = transforms.Compose([
        transforms.Resize((10, 10)),
        transforms.Grayscale(),
        # transforms.ToTensor()
    ])

    train_set = MyNewDataset(data_path, data_transformer, rows_train)

Since I applied the threshold to a PIL image, I need to convert it to a tensor afterwards, but for some reason it crashes. Can somebody please assist me?


Solution

  • Why not apply the binarization after the conversion from PIL.Image to torch.Tensor?

    class ThresholdTransform(object):
      def __init__(self, thr_255):
        self.thr = thr_255 / 255.  # input threshold for [0..255] gray level, convert to [0..1]
    
      def __call__(self, x):
        return (x > self.thr).to(x.dtype)  # do not change the data type
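
As a quick sanity check, the transform can be applied to a small hand-made tensor. This is only a minimal sketch: the class is repeated so the snippet runs on its own, and the sample values are made up for illustration, assuming the input is already scaled to [0..1] the way ToTensor() produces it.

```python
import torch


class ThresholdTransform(object):
    def __init__(self, thr_255):
        # Threshold given on the [0..255] scale, stored on the [0..1] scale
        self.thr = thr_255 / 255.

    def __call__(self, x):
        # Boolean mask cast back to the input dtype
        return (x > self.thr).to(x.dtype)


# Hypothetical 1x2x2 "image" with values in [0..1]
x = torch.tensor([[[0.0, 0.5], [0.95, 1.0]]])

# Values above 240/255 (about 0.941) become 1.0, everything else 0.0
binary = ThresholdTransform(thr_255=240)(x)
print(binary)
print(binary.dtype)
```

Because the mask is cast back with .to(x.dtype), the result stays a float tensor and can be fed to a model without a further conversion.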
    

    Once you have this transformation, simply add it to the pipeline after ToTensor():

    data_transformer = transforms.Compose([
        transforms.Resize((10, 10)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        ThresholdTransform(thr_255=240)
    ])