I want to load a batch of images of different resolutions and split them into non-overlapping patches of equal size on the fly to feed them to a ResNet18 model. Is there an existing transform class in PyTorch that does this? If not, how do I implement my own?
Here's the code:
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    ImageResizer(),  # Custom class to resize the image to the next multiple of 224 (takes a PIL image, returns a PIL image)
    # Patch(patch_size=(224, 224)),  # Custom class to divide the image into 224x224 patches (takes a PIL image, returns a list of PIL images)
    transforms.ToTensor(),
])

dataset = ImageFolder(root="<path>", transform=transform)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
Here's how my ImageResizer code looks:
import math


class ImageResizer:
    """
    Resize an image to the next multiple of 224, so that it can be divided
    into 224x224 patches later.
    """

    @staticmethod
    def get_new_dimensions(width: int, height: int, patch_height: int = 224, patch_width: int = 224) -> tuple[int, int]:
        """
        Get the new dimensions of the image after resizing.

        Parameters:
        - width: The width of the image.
        - height: The height of the image.
        - patch_height: The height of the patch.
        - patch_width: The width of the patch.

        Returns:
        - new_width: The new width of the image.
        - new_height: The new height of the image.
        """
        # Round up so the result is the next multiple of the patch size;
        # this also guarantees images smaller than one patch are scaled up
        # to a full patch instead of being resized to zero.
        width_coef = math.ceil(width / patch_width)
        height_coef = math.ceil(height / patch_height)
        new_width = width_coef * patch_width
        new_height = height_coef * patch_height
        return new_width, new_height

    def __call__(self, image):
        """
        Resize the given image to the next multiple of 224.

        Parameters:
        - image: a PIL image.

        Returns:
        - resized_image: The resized PIL image.
        """
        width, height = image.size
        new_width, new_height = ImageResizer.get_new_dimensions(width, height)
        # Resize the image
        resized_image = image.resize((new_width, new_height))
        return resized_image
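A quick sanity check of the resizing logic (the 500x300 input size is purely illustrative):

from PIL import Image

# 500 and 300 round up to the next multiples of 224: 672 and 448.
assert ImageResizer.get_new_dimensions(500, 300) == (672, 448)

img = Image.new("RGB", (500, 300))
print(ImageResizer()(img).size)  # (672, 448)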
Transforms are expected to take one data point as input (an image in this case) and return a single transformed data point, so patching an image with a custom transform and returning a list of patches is not possible for now.
A possible solution is to provide a custom implementation of the collate_fn function and pass it as an argument to the DataLoader class. The collate_fn function takes as input a list of tuples (the first element of each tuple is the data point and the second is the label) and returns a tuple of two tensors: the first represents a batch of images and the second the corresponding labels.
Below is a possible implementation of the functionality you want:
import torch


def make_patches(
    img: torch.Tensor,
    patch_width: int,
    patch_height: int,
) -> list[torch.Tensor]:
    # img has shape (C, H, W): unfold the height dimension (dim 1), then the
    # width dimension (dim 2), giving (C, n_rows, n_cols, patch_height, patch_width).
    patches = img \
        .unfold(1, patch_height, patch_height) \
        .unfold(2, patch_width, patch_width) \
        .flatten(1, 2) \
        .permute(1, 0, 2, 3)  # (n_patches, C, patch_height, patch_width)
    return list(patches)
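To verify the patch shapes, here is a quick check with a dummy tensor (the 448x672 size is just an illustrative assumption matching what ImageResizer plus ToTensor would produce):

x = torch.rand(3, 448, 672)            # (C, H, W), already a multiple of 224
patches = make_patches(x, 224, 224)
print(len(patches), patches[0].shape)  # 6 torch.Size([3, 224, 224])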
def collate_fn(batch: list[tuple[torch.Tensor, int]]) -> tuple[torch.Tensor, torch.Tensor]:
    new_x = []
    new_y = []
    for x, y in batch:
        # Split each image into patches and repeat its label once per patch.
        patches = make_patches(x, 224, 224)
        new_x.extend(patches)
        new_y.extend([y] * len(patches))
    new_x = torch.stack(new_x)
    new_y = torch.tensor(new_y)
    return new_x, new_y
dataset = ImageFolder(root="<your-path>", transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
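Each batch from this dataloader is now a stack of 224x224 patches, so it can be fed directly to ResNet18. A minimal sketch with randomly initialized weights, just to verify shapes:

from torchvision.models import resnet18

model = resnet18()
model.eval()

images, labels = next(iter(dataloader))
with torch.no_grad():
    logits = model(images)
print(images.shape, logits.shape)  # (n, 3, 224, 224), (n, 1000)

Note that the effective batch size varies: a batch of 32 images yields 32 times however many patches each image produces, so n is generally larger than batch_size.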