
With torch or torchvision, how can I resize and crop an image batch, and get both the resizing scales and the new images?


I want to transform a batch of images so that each one is randomly cropped (with a fixed aspect ratio) and resized (scaled). However, I want not only the new images but also a tensor of the scale factors applied to each image. For example, this torchvision transform does the cropping and resizing I want:

    import torchvision

    scale_transform = torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(1.0, 1.0))
    images_scaled = scale_transform(images_original)

But I also want to know the scale factors. How might I get those scale factors, or tackle this in a different way?


Solution

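  • If I understand correctly, you want the factor by which the cropped part was resized. You can get it by computing the width and height of the cropped region and comparing them to the output size: on each axis, the scale factor is the output size divided by the crop size (its reciprocal gives the crop size relative to the output). Note that `RandomResizedCrop.get_params` returns the crop box as `(top, left, height, width)`, not as corner coordinates.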

    import torch
    import torchvision
    import torchvision.transforms.functional as F


    class MyRandomResizedCrop:
        """RandomResizedCrop that also returns the scale factors it applied."""

        def __init__(self, size, scale, ratio):
            self.size = [size, size]  # output (height, width)
            self.scale = scale
            self.ratio = ratio

        def __call__(self, sample):
            # torchvision >= 0.8 accepts tensors here, so no PIL round-trip is needed.
            # get_params returns the crop box as (top, left, height, width).
            i, j, h, w = torchvision.transforms.RandomResizedCrop.get_params(
                sample, self.scale, self.ratio)

            # Scale factor applied by the resize step: output size / crop size.
            x_ratio = self.size[1] / w
            y_ratio = self.size[0] / h
            ratio = (x_ratio, y_ratio)

            output = F.crop(sample, i, j, h, w)
            output = F.resize(output, self.size)

            return ratio, output


    size = 224
    scale = (0.08, 1.0)
    ratio = (1.0, 1.0)

    t = MyRandomResizedCrop(size, scale, ratio)

    img = torch.rand((3, 1024, 1024), dtype=torch.float32)

    r, img = t(img)  # r is (x_scale, y_scale); img has shape (3, 224, 224)