Tags: pytorch, fft, convolution, theorem-proving

Verify convolution theorem using pytorch


The theorem is usually stated as:

F(f * g) = F(f) · F(g)

where F denotes the Fourier transform, * is convolution, and · is pointwise multiplication.

I know the theorem, but I simply cannot reproduce this result using PyTorch.

Below is a reproducible example:

import torch
import torch.nn.functional as F

# calculate f*g
f = torch.ones((1,1,5,5))
g = torch.tensor(list(range(9))).view(1,1,3,3).float()
conv = F.conv2d(f, g, bias=None, padding=2)

# calculate F(f*g)
F_fg = torch.rfft(conv, signal_ndim=2, onesided=False)

# calculate F x G
f = f.squeeze()
g = g.squeeze()

# need to pad into at least [w1+w2-1, h1+h2-1], which is 7 in our case.
size = f.size(0) + g.size(0) - 1 

f_new = torch.zeros((7,7))
g_new = torch.zeros((7,7))

f_new[1:6,1:6] = f
g_new[2:5,2:5] = g

F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
FxG = torch.mul(F_f, F_g)

print(FxG - F_fg)

Here is the result of print(FxG - F_fg):

tensor([[[[[ 0.0000e+00,  0.0000e+00],
       [ 4.1426e+02,  1.7270e+02],
       [-3.6546e+01,  4.7600e+01],
       [-1.0216e+01, -4.1198e+01],
       [-1.0216e+01, -2.0223e+00],
       [-3.6546e+01, -6.2804e+01],
       [ 4.1426e+02, -1.1427e+02]],

      ...

      [[ 4.1063e+02, -2.2347e+02],
       [-7.6294e-06,  2.2817e+01],
       [-1.9024e+01, -9.0105e+00],
       [ 7.1708e+00, -4.1027e+00],
       [-2.6739e+00, -1.1121e+01],
       [ 8.8471e+00,  7.1710e+00],
       [ 4.2528e+01,  9.7559e+01]]]]])

As you can see, the difference is not always 0.

Can someone tell me why this happens, and how to do it properly?

Thanks


Solution

  • I took a closer look at what you've done so far and identified three sources of error in your code. I'll address each of them here.

    1. Complex arithmetic

    PyTorch doesn't currently support multiplication of complex numbers (AFAIK). The FFT operation simply returns a tensor with a separate real and imaginary dimension. Instead of using torch.mul or the * operator, we need to code the complex multiplication explicitly:

    (a + ib) * (c + id) = (a*c - b*d) + i(a*d + b*c)
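
    As a minimal sketch, this is what that multiplication looks like for rfft-style tensors whose last dimension holds [real, imag] (the helper name complex_mul is mine; the full solution below does the same thing inline):

    import torch

    def complex_mul(x, y):
        # x, y: tensors whose last dimension holds [real, imag]
        real = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
        imag = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
        return torch.stack([real, imag], dim=-1)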

    2. The definition of convolution

    The definition of "convolution" commonly used in the CNN literature is actually different from the one used when stating the convolution theorem. I won't go into detail, but the mathematical definition flips the kernel before sliding and multiplying, whereas the convolution operation in PyTorch, TensorFlow, Caffe, etc. (which is really cross-correlation) doesn't do this flipping.

    To account for this, we can simply flip g (both horizontally and vertically) before applying the FFT.

    3. Anchor position

    The anchor point when using the convolution theorem is assumed to be the upper-left corner of the padded g. Again, I won't go into much detail, but the DFT treats index (0, 0) as the origin, so the kernel's center has to sit there, with the remaining entries wrapping around to the opposite edges.


    The second and third points may be easier to understand with an example. Suppose you used the following g

    [1 2 3]
    [4 5 6]
    [7 8 9]
    

    instead of g_new being

    [0 0 0 0 0 0 0]
    [0 0 0 0 0 0 0]
    [0 0 1 2 3 0 0]
    [0 0 4 5 6 0 0]
    [0 0 7 8 9 0 0]
    [0 0 0 0 0 0 0]
    [0 0 0 0 0 0 0]
    

    it should actually be

    [5 4 0 0 0 0 6]
    [2 1 0 0 0 0 3]
    [0 0 0 0 0 0 0]
    [0 0 0 0 0 0 0]
    [0 0 0 0 0 0 0]
    [0 0 0 0 0 0 0]
    [8 7 0 0 0 0 9]
    

    where we flip the kernel vertically and horizontally, then apply a circular shift so that the center of the kernel lands in the upper-left corner.
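
    One way to sanity-check this layout is the small sketch below, using torch.flip and torch.roll (these calls aren't part of the original code, just an illustration) with the 3x3 kernel and 7x7 padded size from this example:

    import torch

    g = torch.arange(1., 10.).view(3, 3)        # [[1,2,3],[4,5,6],[7,8,9]]

    g_new = torch.zeros(7, 7)
    g_new[:3, :3] = torch.flip(g, dims=(0, 1))  # flip vertically and horizontally
    # circular shift so the kernel's center (currently at index (1, 1)) moves to (0, 0)
    g_new = torch.roll(g_new, shifts=(-1, -1), dims=(0, 1))

    print(g_new)  # matches the matrix above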


    I ended up rewriting most of your code and generalizing it a bit. The trickiest part is constructing g_new properly; I decided to use a meshgrid and modulo arithmetic to flip and shift the indices in one step. If something here doesn't make sense to you, please leave a comment and I'll try to clarify.

    import torch
    import torch.nn.functional as F
    
    def conv2d_pyt(f, g):
        assert len(f.size()) == 2
        assert len(g.size()) == 2
    
        f_new = f.unsqueeze(0).unsqueeze(0)
        g_new = g.unsqueeze(0).unsqueeze(0)
    
        pad_y = (g.size(0) - 1) // 2
        pad_x = (g.size(1) - 1) // 2
    
        fcg = F.conv2d(f_new, g_new, bias=None, padding=(pad_y, pad_x))
        return fcg[0, 0, :, :]
    
    def conv2d_fft(f, g):
        assert len(f.size()) == 2
        assert len(g.size()) == 2
    
        # in general not necessary that inputs are odd shaped but makes life easier
        assert f.size(0) % 2 == 1
        assert f.size(1) % 2 == 1
        assert g.size(0) % 2 == 1
        assert g.size(1) % 2 == 1
    
        size_y = f.size(0) + g.size(0) - 1
        size_x = f.size(1) + g.size(1) - 1
    
        f_new = torch.zeros((size_y, size_x))
        g_new = torch.zeros((size_y, size_x))
    
        # copy f to center
        f_pad_y = (f_new.size(0) - f.size(0)) // 2
        f_pad_x = (f_new.size(1) - f.size(1)) // 2
        f_new[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x] = f
    
        # anchor of g is 0,0 (flip g and wrap circular)
        g_center_y = g.size(0) // 2
        g_center_x = g.size(1) // 2
        g_y, g_x = torch.meshgrid(torch.arange(g.size(0)), torch.arange(g.size(1)))
        g_new_y = (g_y.flip(0) - g_center_y) % g_new.size(0)
        g_new_x = (g_x.flip(1) - g_center_x) % g_new.size(1)
        g_new[g_new_y, g_new_x] = g[g_y, g_x]
    
        # take fft of both f and g
        F_f = torch.rfft(f_new, signal_ndim=2, onesided=False)
        F_g = torch.rfft(g_new, signal_ndim=2, onesided=False)
    
        # complex multiply
        FxG_real = F_f[:, :, 0] * F_g[:, :, 0] - F_f[:, :, 1] * F_g[:, :, 1]
        FxG_imag = F_f[:, :, 0] * F_g[:, :, 1] + F_f[:, :, 1] * F_g[:, :, 0]
        FxG = torch.stack([FxG_real, FxG_imag], dim=2)
    
        # inverse fft
        fcg = torch.irfft(FxG, signal_ndim=2, onesided=False)
    
        # crop center before returning
        return fcg[f_pad_y:-f_pad_y, f_pad_x:-f_pad_x]
    
    
    # calculate f*g
    f = torch.randn(11, 7)
    g = torch.randn(5, 3)
    
    fcg_pyt = conv2d_pyt(f, g)
    fcg_fft = conv2d_fft(f, g)
    
    avg_diff = torch.mean(torch.abs(fcg_pyt - fcg_fft)).item()
    
    print('Average difference:', avg_diff)
    

    Which gives me

    Average difference: 4.6866085767760524e-07
    

    This is very close to zero. The reason we don't get exactly zero is floating-point rounding error.
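
    If you'd rather have a pass/fail check than eyeball the average difference, torch.allclose with a loose tolerance works (the tolerance here is chosen somewhat arbitrarily):

    print(torch.allclose(fcg_pyt, fcg_fft, atol=1e-4))  # expected to print True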