Tags: deep-learning, pytorch, computer-vision, resnet, image-preprocessing

How to load a batch of images and split them into patches on the fly with PyTorch?


I want to load a batch of images of different resolutions and split them on the fly into non-overlapping patches of equal size, so I can feed them to a ResNet18 model. Is there an existing transform class in PyTorch that does this? If not, how do I implement my own?

Here's the code:

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    ImageResizer(),  # Custom class to resize the image to the next multiple of 224 (takes a PIL image, returns a PIL image)
    # Patch(patch_size=(224, 224)),  # Custom class to divide the image into 224x224 patches (takes a PIL image, returns a list of PIL images)
    transforms.ToTensor(),
])

dataset = ImageFolder(root="<path>", transform=transform)

batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
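A minimal sketch of the underlying batching problem, assuming 3-channel images: once images of different sizes reach the DataLoader, its default collate function tries to `torch.stack` them and fails. The two zero tensors below are synthetic stand-ins for the output of `ImageResizer` + `ToTensor`, not real data.

```python
import torch
from torch.utils.data import default_collate

# Synthetic stand-ins (assumption): two "images" whose sides are multiples
# of 224 but whose shapes differ, paired with class labels as ImageFolder does
batch = [
    (torch.zeros(3, 224, 448), 0),
    (torch.zeros(3, 448, 224), 1),
]

try:
    default_collate(batch)
    collate_failed = False
except RuntimeError as err:
    # torch.stack requires every tensor in the batch to have the same shape
    collate_failed = True
    print("default_collate failed:", err)
```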

Here's how my ImageResizer code looks:

import math

class ImageResizer:
    """
    A class to resize the image to the next multiple of 224, so that the images can be divided into 224x224 patches later.
    """

    def __init__(self):
        pass

    @staticmethod
    def get_new_dimensions(width: int, height: int, patch_height: int = 224, patch_width: int = 224):
        """
        Get the new dimensions of the image after resizing.

        Parameters:
        - width: The width of the image.
        - height: The height of the image.
        - patch_height: The height of the patch.
        - patch_width: The width of the patch.

        Returns:
        - new_width: The new width of the image.
        - new_height: The new height of the image.
        """
        # Round *up* to the next multiple; rounding to the nearest multiple
        # would give 0 for sides of half a patch or less
        width_coef = math.ceil(width / patch_width)
        height_coef = math.ceil(height / patch_height)

        new_width = width_coef * patch_width
        new_height = height_coef * patch_height

        return new_width, new_height

    def __call__(self, image):
        """
        Resize the given image so that both sides are multiples of 224.

        Parameters:
        - image: a PIL image.

        Returns:
        - resized_image: the resized PIL image.
        """

        width, height = image.size

        new_width, new_height = ImageResizer.get_new_dimensions(width, height)

        # Resize the image
        resized_image = image.resize((new_width, new_height))

        return resized_image
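A standalone sketch of the dimension arithmetic, useful for checking what sizes the resizer should produce. The helper name `next_multiple_size` is hypothetical and not part of the code above. It rounds up with `math.ceil`, which matches the "next multiple of 224" intent; rounding to the nearest multiple instead (as `np.round` does) would map any side of 112 px or less to 0, giving a zero-size resize target.

```python
import math

def next_multiple_size(width: int, height: int, patch: int = 224) -> tuple[int, int]:
    """Hypothetical helper: smallest (width, height) >= the input whose sides are multiples of `patch`."""
    # math.ceil rounds up, so a side of at least 1 px never becomes 0
    return math.ceil(width / patch) * patch, math.ceil(height / patch) * patch

print(next_multiple_size(500, 300))  # (672, 448)
print(next_multiple_size(100, 100))  # (224, 224)
print(next_multiple_size(448, 224))  # (448, 224), already a multiple: unchanged
```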

Solution

  • Transforms are expected to take a single data point (an image, in this case) as input and return a single transformed data point. A custom transform that returns a list of patches does not fit this contract: the DataLoader's default collate function cannot combine a varying number of patches per image into one batch tensor.

    A possible solution is to write a custom collate_fn and pass it to the DataLoader through its collate_fn argument.

    The collate_fn function receives a list of samples as returned by the dataset (for ImageFolder, each sample is a tuple whose first element is the image and whose second element is its label) and returns a tuple of two tensors: the first is the batch of inputs (here, patches) and the second holds the corresponding labels.

    Below is a possible implementation of what you want:

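    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets

    def make_patches(
        img: torch.Tensor,
        patch_width: int,
        patch_height: int,
    ) -> list[torch.Tensor]:
        """Split a (C, H, W) tensor into non-overlapping (C, patch_height, patch_width) patches.

        H and W should be multiples of the patch size; unfold drops any remainder.
        """
        patches = (
            img.unfold(1, patch_height, patch_height)  # dim 1 is H -> (C, nH, W, ph)
            .unfold(2, patch_width, patch_width)       # dim 2 is W -> (C, nH, nW, ph, pw)
            .flatten(1, 2)                             # (C, nH * nW, ph, pw)
            .permute(1, 0, 2, 3)                       # (nH * nW, C, ph, pw)
        )

        return list(patches)

    def collate_fn(batch: list[tuple[torch.Tensor, int]]) -> tuple[torch.Tensor, torch.Tensor]:

        new_x = []
        new_y = []

        for x, y in batch:
            patches = make_patches(x, 224, 224)
            new_x.extend(patches)
            # Every patch inherits the label of the image it came from
            new_y.extend([y] * len(patches))

        new_x = torch.stack(new_x)   # (total_patches, C, 224, 224)
        new_y = torch.tensor(new_y)  # (total_patches,)

        return new_x, new_y

    # `transform` and `batch_size` are the ones defined in the question
    dataset = datasets.ImageFolder(root="<your-path>", transform=transform)

    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)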