pytorch, convolution, generative-adversarial-network, deconvolution

How to implement fractionally strided convolution layers in pytorch?


First of all, I searched Google and Stack Overflow but did not find any similar questions, so I am asking a new one.

I'm interested in this paper and want to implement its SGAN for my project. The paper mentions that the generator network is composed of "a stack of fractionally strided convolution layers". I found two different ways of implementing this in PyTorch. One is:

torch.nn.Sequential(
    # other layers...
    torch.nn.ConvTranspose2d(),
    # other layers...
)

The other way is:

torch.nn.Sequential(
    # other layers...
    torch.nn.Upsample(scale_factor=2),
    torch.nn.Conv2d(),
    # other layers...
)

So my question is: which is the better implementation of a fractionally strided convolution layer, or am I misunderstanding something completely?

Thanks in advance.

P.S. I found the second implementation here, in lines 87-88.


Solution

  • tl;dr: There are some shape constraints, but both perform the same operation.


    The output shape of nn.ConvTranspose2d is given by y = (x - 1)s - 2p + d(k - 1) + p_out + 1, where x and y are the input and output sizes, respectively, k is the kernel size, s the stride, d the dilation, and p and p_out the padding and output padding. Here we keep things simple with s=1, p=0, p_out=0, d=1.

    Therefore, the output shape of the transposed convolution is:

    y =  x - 1 + k
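As a quick sanity check of this formula, here is a small sketch in plain Python. The helper name `conv_transpose_out_size` is hypothetical (it is not a PyTorch function); it simply evaluates the expression above for one spatial dimension.

```python
def conv_transpose_out_size(x, k, s=1, p=0, p_out=0, d=1):
    """Output size of a transposed convolution along one spatial dimension.

    Hypothetical helper: evaluates y = (x - 1)s - 2p + d(k - 1) + p_out + 1.
    """
    return (x - 1) * s - 2 * p + d * (k - 1) + p_out + 1


# With s=1, p=0, p_out=0, d=1 the expression collapses to x - 1 + k.
for x in range(1, 6):
    for k in range(1, 5):
        assert conv_transpose_out_size(x, k) == x - 1 + k

print(conv_transpose_out_size(4, 2))  # 5
```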
    

    Now consider an upsample (x2) followed by a convolution. Using the same notation as before, the output size of nn.Conv2d is given by y = floor((x + 2p - d(k - 1) - 1) / s + 1). After upsampling, the input to the convolution has size 2x. Keeping the dilation at d=1, this becomes:

    y = floor((2x + 2p - k) / s + 1)
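The same kind of check can be sketched for the upsample-then-convolve path. Again this is plain Python with a hypothetical helper name (`upsample_conv_out_size`), assuming an integer scale factor of 2 as in the question.

```python
def upsample_conv_out_size(x, k, s=1, p=0, d=1, scale=2):
    """Output size of upsampling by `scale`, then convolving, along one dimension.

    Hypothetical helper: evaluates floor((scale*x + 2p - d(k - 1) - 1) / s + 1).
    """
    upsampled = scale * x
    # Integer floor division implements the floor in the formula.
    return (upsampled + 2 * p - d * (k - 1) - 1) // s + 1


# With d=1 this agrees with floor((2x + 2p - k) / s + 1).
for x in range(1, 6):
    for k in range(1, 4):
        for s in (1, 2):
            for p in (0, 1):
                assert upsample_conv_out_size(x, k, s, p) == (2 * x + 2 * p - k) // s + 1

print(upsample_conv_out_size(4, 2, s=2, p=1))  # 5
```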
    

    If we want to match the output shape of the transposed convolution, we need to have x - 1 + k = floor((2x + 2p - k) / s + 1). This relation will define the values to choose for s and p for our convolution.

    Take a simple example for demonstration: k=2. Now x + 1 needs to equal floor((2x + 2p - 2) / s + 1), which is solved by setting s=2 and p=1, since floor((2x + 2 - 2) / 2 + 1) = x + 1.
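The k=2 case can be checked directly in PyTorch. The snippet below is only a sketch: the single input/output channel and the 5x5 input are arbitrary choices for illustration. It compares output shapes only; the two layers are initialized with independent random weights, so their output values will differ.

```python
import torch
from torch import nn

x = torch.randn(1, 1, 5, 5)  # (batch, channels, height, width); sizes are arbitrary

# Option 1: transposed convolution with k=2, s=1, p=0.
transposed = nn.ConvTranspose2d(1, 1, kernel_size=2)

# Option 2: upsample by 2 (nearest neighbour by default), then convolve with k=2, s=2, p=1.
upsample_conv = nn.Sequential(
    nn.Upsample(scale_factor=2),
    nn.Conv2d(1, 1, kernel_size=2, stride=2, padding=1),
)

y_transposed = transposed(x)
y_upsample_conv = upsample_conv(x)

print(y_transposed.shape)     # torch.Size([1, 1, 6, 6])
print(y_upsample_conv.shape)  # torch.Size([1, 1, 6, 6])
```

Both paths map a 5x5 input to a 6x6 output, i.e. x + 1 along each spatial dimension.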


    Here is the same example in a visual form.

    [figure]

    • transposed convolution

    [figure]

    • upsample + convolution

    [figure]
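The k=2 example can also be traced numerically in one dimension. The sketch below is plain Python rather than PyTorch, and the function names are hypothetical. It assumes nearest-neighbour upsampling (the default mode of nn.Upsample), zero padding, no bias, and that the convolution uses the transposed convolution's kernel in reversed order.

```python
def transposed_conv1d(x, w):
    """Stride-1, unpadded transposed convolution: each input scatters a scaled copy of w."""
    out = [0.0] * (len(x) - 1 + len(w))
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            out[i + j] += xi * wj
    return out


def upsample_conv1d(x, w, stride=2, padding=1):
    """Nearest-neighbour x2 upsampling, zero padding, then a strided cross-correlation."""
    up = [xi for xi in x for _ in range(2)]
    padded = [0.0] * padding + up + [0.0] * padding
    n_out = (len(padded) - len(w)) // stride + 1
    return [
        sum(wj * padded[i * stride + j] for j, wj in enumerate(w))
        for i in range(n_out)
    ]


x = [1.0, 2.0, 3.0]
w = [10.0, 1.0]  # arbitrary k=2 kernel

a = transposed_conv1d(x, w)
b = upsample_conv1d(x, w[::-1])  # same weights, reversed

print(a)  # [10.0, 21.0, 32.0, 3.0]
print(b)  # [10.0, 21.0, 32.0, 3.0]
```

Under these assumptions the two outputs coincide element by element, not just in length. This illustrates the k=2 case only; it is not a general proof for other kernel sizes or upsampling modes.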