In PyTorch, I have a dataset of 2D images (or, equivalently, 1-channel images) and I'd like to apply 2D average pooling as a transform. How do I do this? The following does not work:
omniglot_dataset = torchvision.datasets.Omniglot(
    root=data_dir,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop((80, 80)),
        # torchvision.transforms.Resize((10, 10))
        torch.nn.functional.avg_pool2d(kernel_size=3, strides=1),
    ])
)
Transforms have to be callable objects, but torch.nn.functional.avg_pool2d is not one: it is a plain function that you call directly with both the input and the parameters, which is why it is packaged under torch.nn.functional with the other functionals. You need to use the module version instead:
torch.nn.AvgPool2d(kernel_size=3, stride=1)
This returns a callable object that can be applied to a given input, for example:
pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
output = pooler(input)
With this change in place, here are several ways to use the callable version:
import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
omniglotv1 = torchvision.datasets.Omniglot(
    root='./dataset/',
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop((80, 80))
    ])
)
x1, y = omniglotv1[0]
print(x1.size()) # torch.Size([1, 80, 80])
omniglotv2 = torchvision.datasets.Omniglot(
    root='./dataset/',
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop((80, 80)),
        torch.nn.AvgPool2d(kernel_size=3, stride=1)
    ])
)
x2, y = omniglotv2[0]
print(x2.size()) # torch.Size([1, 78, 78])
pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
omniglotv3 = torchvision.datasets.Omniglot(
    root='./dataset/',
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.CenterCrop((80, 80)),
        pooler
    ])
)
x3, y = omniglotv3[0]
print(x3.size()) # torch.Size([1, 78, 78])
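The drop from 80 to 78 follows the usual pooling size formula: with no padding, the output spatial size is floor((H - k) / s) + 1. A quick sketch to check the arithmetic:

```python
# Output spatial size of AvgPool2d with no padding:
# floor((H - k) / s) + 1, for input size H, kernel k, stride s.
def pooled_size(h, k, s):
    return (h - k) // s + 1

print(pooled_size(80, 3, 1))  # 78, matching torch.Size([1, 78, 78]) above
print(pooled_size(80, 3, 2))  # 39, if you used stride=2 instead
```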
Finally, here is a short snippet that displays the original and pooled images stacked together, so you can see what the transform does:
x_img = x1.squeeze().cpu().numpy()
ave_img = x2.squeeze().cpu().numpy()
combined = np.zeros((158, 80))
combined[0:80, 0:80] = x_img
combined[80:, 0:78] = ave_img
plt.imshow(combined)
plt.show()