I would like to load images from a local directory in Torch in order to train a binary classifier. My directory looks as follows:
-data
- class_1_folder
- class_2_folder
My folders class_1 and class_2 contain the .jpg images for each class. My images do have various sizes (mostly rectangular shapes though). I am using the following code to read my files:
import os
import pdb
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
PATH = "data/"
transform = transforms.Compose([transforms.Resize(256),
                                transforms.RandomCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
dataset = datasets.ImageFolder(PATH, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=356, shuffle=True)
images, labels = next(iter(dataloader))
tensor_image = images[1]
# inv_normalize is the inverse transform defined further below
img = to_pil_image(inv_normalize(tensor_image))
plt.imshow(img)
plt.show()
When I check these results with imshow, it seems that imshow portrays a grid with the image nine times (3x3). How can I avoid that? Is there also an easy way to revert the transformations before the imshow?
I am trying to invert the normalization with something like the following:
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.255],
    std=[1/0.229, 1/0.224, 1/0.255]
)
However, the results are still a bit weird! Isn't that a correct inverse transformation?
You can undo the normalization by hand and then use to_pil_image:
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from torchvision.transforms.functional import to_pil_image
tr = transforms.Compose([transforms.Resize(256),
                         transforms.RandomCrop(224),
                         transforms.ToTensor(),
                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img = Image.open("/content/cat-2083492__340.jpg")
plt.imshow(img)  # original image
img = tr(img)
plt.figure()
print(img.shape, type(img))
# Undo Normalize per channel: x = y * std + mean
img = img * torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1) + torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
img = to_pil_image(img)
plt.imshow(img)  # transformed image with the normalization reverted
plt.show()
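Note that to_pil_image does not undo the normalization. It only converts the tensor to a PIL image (reordering the dimensions from C x H x W to H x W x C and scaling float values to 8-bit), so the normalization has to be inverted first, as in the code above.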
https://pytorch.org/vision/main/generated/torchvision.transforms.ToPILImage.html#torchvision.transforms.ToPILImage
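One more thing worth checking: the inverse only recovers the image if it uses exactly the same statistics as the forward transform. The inv_normalize in the question uses 0.255 for the third channel, while the forward Normalize uses 0.225, which alone would distort that channel. Below is a minimal sketch of the round trip using plain torch arithmetic. The helper names normalize and denormalize are hypothetical (not torchvision functions), and the random tensor only stands in for a real ToTensor() image.

```python
import torch

# Statistics from the forward Normalize transform, shaped to broadcast over (C, H, W).
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def normalize(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: same arithmetic as transforms.Normalize, (x - mean) / std."""
    return (x - MEAN) / STD


def denormalize(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: invert normalize() via x = y * std + mean, clamped to [0, 1]."""
    return (y * STD + MEAN).clamp(0.0, 1.0)


# Equivalent parameters if the inverse is written as another Normalize-style step:
# mean' = -mean / std and std' = 1 / std.
INV_MEAN = -MEAN / STD
INV_STD = 1.0 / STD

torch.manual_seed(0)
original = torch.rand(3, 4, 4)  # stand-in for an image tensor with values in [0, 1]
normalized = normalize(original)

restored = denormalize(normalized)
restored_alt = (normalized - INV_MEAN) / INV_STD  # Normalize-style inverse

print(torch.allclose(original, restored, atol=1e-5))
print(torch.allclose(restored, restored_alt, atol=1e-5))
```

Clamping to [0, 1] before calling to_pil_image is a defensive choice here, so that small floating-point overshoots do not wrap around when the tensor is converted to 8-bit values.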