Tags: filter, pytorch, torchvision

PyTorch: Applying a batch of filters (kernels) to a single picture using conv2d


I have a batch of filters, i.e., w, whose size is torch.Size([64, 3, 7, 7]) as follows:

[image: the filter tensor w]

Also, I have a picture p from ImageNet as follows:

[image: the picture p]

How can I apply the filters to the picture and get a grid of the 64 filtered images, where each cell contains the same picture with a different filter applied? I would like to build the grid using torchvision.utils.make_grid, but I don't know how.

My attempt

y = F.conv2d(p, w)

The size of y is torch.Size([1, 64, 250, 250]), which does not make sense to me.


Solution

  • Each of your filters has size [3, 7, 7], so each one takes an RGB image and produces a single-channel output. Those 64 single-channel outputs are stacked along the channel dimension, so your output of shape [1, 64, H, W] makes perfect sense.
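
    Here is a minimal sketch of that point (using a random stand-in image, so the tensors and sizes are assumptions): channel i of the output is just the convolution of the picture with filter i on its own.

    import torch
    import torch.nn.functional as F
    
    img = torch.randn(1, 3, 256, 256)       # stand-in for the RGB picture, [N, C, H, W]
    filters = torch.randn(64, 3, 7, 7)      # 64 filters, each spanning all 3 input channels
    
    out = F.conv2d(img, filters)            # [1, 64, 250, 250]: one output channel per filter
    single = F.conv2d(img, filters[0:1])    # [1, 1, 250, 250]: filter 0 applied on its own
    print(torch.allclose(out[:, 0:1], single))  # should print True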

    To visualize these filters:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision
    from torchvision import transforms
    from PIL import Image
    
    import matplotlib.pyplot as plt
    
    torch.random.manual_seed(42)
    
    # load the picture as a [1, 3, H, W] tensor
    transform = transforms.Compose([transforms.ToTensor()])
    img = transform(Image.open('dog.jpg')).unsqueeze(0)
    print('Image size: ', img.shape)
    
    # 64 random filters, each covering all 3 input channels
    filters = torch.randn(64, 3, 7, 7)
    
    # one output channel per filter: [1, 64, H-6, W-6]
    out = F.conv2d(img, filters)
    print('Output size: ', out.shape)
    
    # split the 64 output channels into a list of single-channel images
    list_of_images = [out[:, i] for i in range(64)]
    
    # arrange the images in a grid and display it
    grid = torchvision.utils.make_grid(list_of_images, normalize=True)
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.show()
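
    Each element of list_of_images has shape [1, H, W], i.e. a single-channel image, so make_grid renders the grid in grayscale (it simply replicates the single channel three times). If you prefer to skip the list comprehension, make_grid also accepts a batched tensor, so a call along these lines should produce the same grid:

    # same grid, built from a [64, 1, H, W] view of the output
    grid = torchvision.utils.make_grid(out.transpose(0, 1), normalize=True)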
    

    [image: grid of the 64 filtered outputs, rendered in grayscale]

    This is an accurate representation of the output. It is, however, not very attractive; we can obtain a colored version by processing each color channel independently. (The grayscale version can be recovered from it by summing over the color channels; a quick check of this is sketched at the end of the answer.)

    # convolve each color channel with the matching channel of every filter
    color_out = []
    for i in range(3):
        color_out.append(F.conv2d(img[:, i:i+1], filters[:, i:i+1]))
    # stack the per-channel results along a new color dimension: [1, 64, 3, H-6, W-6]
    out = torch.stack(color_out, 2)
    print('Output size: ', out.shape)
    
    # now there is one 3-channel image per filter
    list_of_images = [out[0, i] for i in range(64)]
    print(list_of_images[0].shape)
    
    grid = torchvision.utils.make_grid(list_of_images, normalize=True)
    plt.imshow(grid.numpy().transpose(1, 2, 0))
    plt.show()
    

    [image: grid of the 64 filtered outputs in color]
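
    As a sanity check for the claim above that the grayscale version can be recovered by summing over the color channels, here is a minimal sketch (it reuses the img, filters and out variables from the code above; the tolerance value is just an assumption to absorb float rounding):

    # conv2d sums over input channels internally, so summing our per-channel
    # results along the color dimension should match the original output
    gray_from_color = out.sum(dim=2)         # [1, 64, H-6, W-6]
    gray_direct = F.conv2d(img, filters)     # [1, 64, H-6, W-6]
    print(torch.allclose(gray_from_color, gray_direct, atol=1e-4))  # should print True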