PyTorch - How to use Avg 2d Pooling as a dataset transform?


In PyTorch, I have a dataset of 2D images (i.e., 1-channel images) and I'd like to apply 2D average pooling as a transform. How do I do this? The following does not work:

    omniglot_dataset = torchvision.datasets.Omniglot(
        root=data_dir,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            # torchvision.transforms.Resize((10, 10))
            torch.nn.functional.avg_pool2d(kernel_size=3, strides=1),
        ])
    )

Solution

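  • A transform has to be a callable object. torch.nn.functional.avg_pool2d does not return a callable; it is the pooling function itself, and it expects the input tensor together with the parameters at the moment it is called. That is why it lives under torch.nn.functional, where every function receives both the input and the parameters. In your snippet it is called without an input, so it fails before Compose is even constructed (note also that the keyword is stride, not strides). You need to use the module version instead: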

    torch.nn.AvgPool2d(kernel_size=3, stride=1)
    

    This returns a callable object that can be applied to a given input, for example:

    pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
    output = pooler(input)
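
To make the distinction concrete, here is a minimal sketch comparing the two forms on a small made-up tensor. The names `x` and `pool_fn` are illustrative and not from the question. Binding the parameters with `functools.partial` is one possible way to turn the functional form into a callable; it was not part of the original answer.

```python
import functools

import torch
import torch.nn.functional as F

# Illustrative input: one channel, 4x4, laid out as (C, H, W).
x = torch.arange(16, dtype=torch.float32).reshape(1, 4, 4)

# Module form: constructing it returns a callable object.
pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)

# Functional form: it needs the input at call time, so bind the
# parameters first to obtain something callable on a tensor alone.
pool_fn = functools.partial(F.avg_pool2d, kernel_size=3, stride=1)

out_module = pooler(x)
out_functional = pool_fn(x)

print(out_module.shape)
print(torch.allclose(out_module, out_functional))
```

Either `pooler` or `pool_fn` could then be placed in the `Compose` list, since both are callables that take a tensor and return a tensor.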
    

    With this change, here are a few ways to use the callable version:

    import torchvision
    import torch
    import matplotlib.pyplot as plt
    
    omniglotv1 = torchvision.datasets.Omniglot(
            root='./dataset/',
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.CenterCrop((80, 80))
            ])
        )
    
    x1, y = omniglotv1[0]
    print(x1.size())      # torch.Size([1, 80, 80])
    
    omniglotv2 = torchvision.datasets.Omniglot(
            root='./dataset/',
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.CenterCrop((80, 80)),
                torch.nn.AvgPool2d(kernel_size=3, stride=1)
            ])
        )
    
    x2, y = omniglotv2[0]
    print(x2.size())      # torch.Size([1, 78, 78])
    
    pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
    omniglotv3 = torchvision.datasets.Omniglot(
            root='./dataset/',
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.CenterCrop((80, 80)),
                pooler
            ])
        )
    
    x3, y = omniglotv3[0]
    print(x3.size())      # torch.Size([1, 78, 78])
    

    Here is a short snippet that plots both images to show what the transform does:

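    import numpy as np

    # Drop the channel dimension and convert to NumPy arrays for plotting.
    x_img   = x1.squeeze().cpu().numpy()   # (80, 80) original crop
    ave_img = x2.squeeze().cpu().numpy()   # (78, 78) after average pooling
    # Stack the two images vertically on a single canvas.
    combined = np.zeros((158, 80))
    combined[0:80, 0:80] = x_img
    combined[80:, 0:78] = ave_img
    plt.imshow(combined)
    plt.show()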
    
