python | pytorch | classification

Read images from folder to tensors in torch and run a binary classifier


I would like to load images from a local directory in Torch in order to train a binary classifier. My directory looks as follows:

-data
   - class_1_folder
   - class_2_folder

My folders class_1_folder and class_2_folder contain the .jpg images for each class. My images come in various sizes (mostly rectangular shapes, though). I am using the following code to read my files:

import os
import pdb
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

PATH = "data/"

transform = transforms.Compose([transforms.Resize(256),
                            transforms.RandomCrop(224),
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225])])

dataset = datasets.ImageFolder(PATH, transform=transform)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=356, shuffle=True)
images, labels = next(iter(dataloader))
tensor_image = images[1]
img = to_pil_image(inv_normalize(tensor_image))  # inv_normalize is defined further down
plt.imshow(img)
plt.show()

When I check these results with imshow, it seems that imshow portrays a grid with the image repeated nine times (3×3). How can I avoid that? Is there also a way to easily revert the transformations before calling imshow?

I am trying to invert the normalization with something like the following:

inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
    std=[1/0.229, 1/0.224, 1/0.255]
)

However, the results still look a bit off. Isn't that the correct inverse transformation?


Solution

  • You can use to_pil_image to convert the normalized tensor back to a displayable PIL image:

    import torch
    import matplotlib.pyplot as plt
    from torchvision import transforms
    from PIL import Image
    from torchvision.transforms.functional import to_pil_image

    tr = transforms.Compose([transforms.Resize(256),
                             transforms.RandomCrop(224),
                             transforms.ToTensor(),
                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    img = Image.open("/content/cat-2083492__340.jpg")
    plt.imshow(img)  # original image

    img = tr(img)
    plt.figure()
    print(img.shape, type(img))  # torch.Size([3, 224, 224]) <class 'torch.Tensor'>

    # Undo the normalization: x * std + mean, broadcast over the (C, H, W) tensor
    img = img * torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    img = to_pil_image(img)
    plt.imshow(img)  # transformed image, un-normalized for display
    

    Note that to_pil_image converts the (C, H, W) tensor back into a PIL image, reordering the dimensions to H x W x C and mapping float values in [0, 1] to 8-bit pixel values:
    https://pytorch.org/vision/main/generated/torchvision.transforms.ToPILImage.html#torchvision.transforms.ToPILImage
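
  • Equivalently, you can express the un-normalization as a Normalize transform, which is what the question attempts: the inverse of Normalize(mean, std) is Normalize(-mean/std, 1/std). Note that the inverse in the question uses 0.255 in two places where the forward std is 0.225, which by itself throws the inversion off. Below is a minimal sketch with the ImageNet statistics used above, assuming dataloader is the DataLoader built in the question:

    import torch
    import matplotlib.pyplot as plt
    from torchvision import transforms
    from torchvision.transforms.functional import to_pil_image

    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])

    # Normalize(-mean/std, 1/std) undoes Normalize(mean, std):
    # ((x - mean) / std + mean / std) * std == x
    inv_normalize = transforms.Normalize((-mean / std).tolist(), (1 / std).tolist())

    # For example, applied to one image from the question's DataLoader:
    images, labels = next(iter(dataloader))
    plt.imshow(to_pil_image(inv_normalize(images[1])))
    plt.show()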