
Algorithm of how Conv2d is implemented in PyTorch


I am asking this because I am writing inference code for a PyTorch model exported to ONNX.

Assume I have an image with dimensions 32 x 32 x 3 (CIFAR-10 dataset). I pass it through a Conv2d layer with dimensions 3 x 192 x 5 x 5. The command I used is: Conv2d(3, 192, kernel_size=5, stride=1, padding=2)

Using the formula (see p. 12 of https://arxiv.org/pdf/1603.07285.pdf for reference), I should get an output with dimensions 28 x 28 x 192 (input - kernel + 1 = 32 - 5 + 1).

My question is: how does PyTorch apply this 4D tensor 3 x 192 x 5 x 5 to get an output of 28 x 28 x 192? The layer is a 4D tensor, while the input image is a 3D (32 x 32 x 3) one.

How is the 5 x 5 kernel spread over the 32 x 32 x 3 image matrix? What does the kernel convolve with first: the 3 x 192 part or the 32 x 32 part?

Note: I understand the 2D case. My questions are about 3 or more dimensions.


Solution

  • The input to Conv2d is a tensor of shape (N, C_in, H_in, W_in) and the output is of shape (N, C_out, H_out, W_out), where N is the batch size (number of images), C is the number of channels, H is the height and W is the width. The output height and width H_out, W_out are computed as follows (ignoring the dilation):

    H_out = floor((H_in + 2*padding[0] - kernel_size[0]) / stride[0]) + 1
    W_out = floor((W_in + 2*padding[1] - kernel_size[1]) / stride[1]) + 1
    

    See cs231n for an explanation of how these formulas are derived.

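    In your example N = 1, C_in = 3, H_in = W_in = 32, kernel_size = (5, 5) and stride = (1, 1). With padding = (0, 0) this gives H_out = W_out = (32 - 5)/1 + 1 = 28. Note that the command in the question uses padding=2, which gives H_out = W_out = (32 + 2*2 - 5)/1 + 1 = 32, i.e. a 32 x 32 output; to get 28 x 28 you need padding=0.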

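    C_out = 192 means that there are 192 different filters, each of shape (C_in, kernel_size[0], kernel_size[1]) = (3, 5, 5). PyTorch stores them together in one weight tensor of shape (C_out, C_in, kH, kW) = (192, 3, 5, 5), not 3 x 192 x 5 x 5. The kernel slides only over height and width, never over channels: at each output position, a filter's 3 x 5 x 5 values are multiplied element-wise with the 3 x 5 x 5 input patch beneath it and summed over all three channels (plus that filter's bias), giving a single number. (Strictly, PyTorch computes a cross-correlation, i.e. the kernel is not flipped.) Each filter therefore produces a 2D map of shape (H_out, W_out) = (28, 28), and since there are C_out = 192 filters and N = 1 images, the final output has shape (N, C_out, H_out, W_out) = (1, 192, 28, 28).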

    To understand how exactly the convolution is performed see the convolution demo.

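The output-size formula can be checked with a few lines of plain Python. This is a minimal sketch assuming dilation 1; `conv2d_output_size` is a hypothetical helper name, not a PyTorch function. Integer division `//` implements the floor, which matters once the stride does not divide the remaining size evenly (the stride-2 case below).

```python
def conv2d_output_size(h_in, w_in, kernel_size, stride=(1, 1), padding=(0, 0)):
    """(H_out, W_out) of a 2D convolution with dilation 1 (hypothetical helper)."""
    # `//` floors the division, as in the formula for H_out and W_out.
    h_out = (h_in + 2 * padding[0] - kernel_size[0]) // stride[0] + 1
    w_out = (w_in + 2 * padding[1] - kernel_size[1]) // stride[1] + 1
    return h_out, w_out


# The example from the question: 32 x 32 input, 5 x 5 kernel.
print(conv2d_output_size(32, 32, (5, 5)))                  # (28, 28)
print(conv2d_output_size(32, 32, (5, 5), padding=(2, 2)))  # (32, 32)
print(conv2d_output_size(32, 32, (5, 5), stride=(2, 2)))   # (14, 14): 27 / 2 = 13.5 floors to 13
```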
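For an inference engine, the description of the 192 filters translates directly into nested loops. The sketch below uses NumPy rather than PyTorch and assumes dilation 1, groups 1 and the same integer stride and padding on both axes; `conv2d_naive` is a hypothetical name. It follows the `(C_out, C_in, kH, kW)` weight layout of `nn.Conv2d` (ONNX's `Conv` operator uses the same `M x C/group x kH x kW` layout) and, like PyTorch, computes a cross-correlation without flipping the kernel.

```python
import numpy as np


def conv2d_naive(x, weight, bias=None, stride=1, padding=0):
    """Direct Conv2d forward pass (dilation 1, groups 1), PyTorch weight layout.

    x:      (N, C_in, H_in, W_in)
    weight: (C_out, C_in, kH, kW)
    bias:   (C_out,) or None
    returns (N, C_out, H_out, W_out)
    """
    n, c_in, h_in, w_in = x.shape
    c_out, c_in_w, kh, kw = weight.shape
    assert c_in == c_in_w, "each filter needs one kH x kW slice per input channel"

    # Zero-pad only the two spatial dimensions, never the batch or channel axes.
    x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    h_out = (h_in + 2 * padding - kh) // stride + 1
    w_out = (w_in + 2 * padding - kw) // stride + 1

    out = np.zeros((n, c_out, h_out, w_out), dtype=x.dtype)
    for img in range(n):                # each image in the batch
        for k in range(c_out):          # each of the C_out filters
            for i in range(h_out):      # each output row
                for j in range(w_out):  # each output column
                    r, c = i * stride, j * stride
                    patch = x[img, :, r:r + kh, c:c + kw]  # (C_in, kH, kW) window
                    # Multiply element-wise and sum over ALL input channels and
                    # kernel positions: one scalar per output pixel.
                    out[img, k, i, j] = np.sum(patch * weight[k])
            if bias is not None:
                out[img, k] += bias[k]  # one bias value per filter
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))  # small input so the Python loops stay fast
w = rng.standard_normal((4, 3, 5, 5))  # 4 filters, each of shape (C_in=3, 5, 5)
b = rng.standard_normal(4)
print(conv2d_naive(x, w, b, padding=2).shape)  # (1, 4, 8, 8)
```

If PyTorch is installed, `torch.nn.functional.conv2d(torch.from_numpy(x), torch.from_numpy(w), torch.from_numpy(b), padding=2)` should agree with this result up to floating-point rounding.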
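A direct nested loop over every filter and output position is slow. One common way to implement convolution efficiently is "im2col": copy every input patch into a column, reshape the filters to a `(C_out, C_in*kH*kW)` matrix (here `(192, 75)`), and do a single matrix multiplication. PyTorch dispatches to different backends (for example cuDNN on GPUs) that may use other algorithms, so treat this NumPy sketch as an illustration of the idea rather than what PyTorch actually runs for your layer. `im2col` and `conv2d_im2col` are hypothetical names, and the sketch assumes dilation 1, groups 1 and the same integer stride and padding on both axes.

```python
import numpy as np


def im2col(x, kh, kw, stride=1, padding=0):
    """Copy every kh x kw window of x into a column (hypothetical helper).

    x: (N, C_in, H_in, W_in) -> cols: (N, C_in*kh*kw, H_out*W_out)
    """
    n, c_in, h_in, w_in = x.shape
    x = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    h_out = (h_in + 2 * padding - kh) // stride + 1
    w_out = (w_in + 2 * padding - kw) // stride + 1
    cols = np.empty((n, c_in, kh, kw, h_out, w_out), dtype=x.dtype)
    for di in range(kh):
        for dj in range(kw):
            # Output pixel (i, j) reads input pixel (i*stride + di, j*stride + dj).
            cols[:, :, di, dj] = x[:, :, di:di + stride * h_out:stride,
                                   dj:dj + stride * w_out:stride]
    # Row order (c, di, dj) matches weight.reshape(C_out, C_in*kh*kw).
    return cols.reshape(n, c_in * kh * kw, h_out * w_out), h_out, w_out


def conv2d_im2col(x, weight, bias=None, stride=1, padding=0):
    """Conv2d forward pass as one matrix multiply (dilation 1, groups 1)."""
    c_out, c_in, kh, kw = weight.shape
    cols, h_out, w_out = im2col(x, kh, kw, stride, padding)
    w_mat = weight.reshape(c_out, c_in * kh * kw)  # one row per filter
    out = w_mat @ cols  # (C_out, K) @ (N, K, L) broadcasts to (N, C_out, L)
    if bias is not None:
        out = out + bias[:, None]  # one bias per output channel
    return out.reshape(x.shape[0], c_out, h_out, w_out)


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 32, 32)).astype(np.float32)  # one CIFAR-10-sized image
w = rng.standard_normal((192, 3, 5, 5)).astype(np.float32)  # nn.Conv2d(3, 192, 5).weight layout
b = rng.standard_normal(192).astype(np.float32)
print(conv2d_im2col(x, w, b, padding=0).shape)  # (1, 192, 28, 28)
print(conv2d_im2col(x, w, b, padding=2).shape)  # (1, 192, 32, 32)
```

PyTorch exposes the unfolding step itself as `torch.nn.functional.unfold`, which can be handy for comparing against the columns produced here.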