Efficient pseudo-inverse for PyTorch 2D convolution


Thanks for your attention! I am learning the basics of 2D convolution, linear algebra, and PyTorch. I have run into a problem implementing the pseudo-inverse of the convolution operator: I do not know how to implement it efficiently. Please see the problem statement below for details. Any help, tips, or suggestions are welcome.

The Original Problem:

I have an image feature x with shape [b,c,h,w] and a 3x3 convolutional kernel K with shape [c,c,3,3], with y = K * x. How can I apply the corresponding pseudo-inverse to y efficiently?

Given [y = K * x = Ax], how can I implement [x_hat = (A^+)y]?

I guess that the solution involves torch.fft, but I do not know how to implement it, or whether an implementation already exists.

import torch
import torch.nn.functional as F

c = 32
K = torch.randn(c, c, 3, 3)
x = torch.randn(1, c, 128, 128)
y = F.conv2d(x, K, padding=1)


# How to implement pseudo-inverse for y = K * x in an efficient way?

Some of My Efforts:

I know that 2D convolution is a linear operator, equivalent to a matrix product. We could write out the matrix form of the convolution and compute its pseudo-inverse explicitly. However, I think this would be inefficient, and I do not know how to do it efficiently.

According to Wikipedia, the pseudo-inverse satisfies A(A_pinv(A(x))) = A(x), where A is the convolution operator, A_pinv is its pseudo-inverse, and x is any image feature. When A is surjective (full row rank), this reduces to A(A_pinv(x)) = x.
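
For reference, the explicit matrix route described above can be sketched on a tiny example. This is only a minimal sketch under assumed sizes (c=2, 5x5 images, float64); the helper `conv` and the variable names are illustrative. It builds A column by column from unit images, then checks the Moore-Penrose property A A^+ A = A.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c, h, w = 2, 5, 5  # tiny sizes so the dense matrix stays small (assumption)
K = torch.randn(c, c, 3, 3, dtype=torch.float64)

def conv(x):
    # Same operator as in the question: 3x3 kernel with zero padding of 1.
    return F.conv2d(x, K, padding=1)

# Convolution is linear, so column i of A is the response to the i-th unit image.
n = c * h * w
basis = torch.eye(n, dtype=torch.float64).reshape(n, c, h, w)
A = conv(basis).reshape(n, -1).T  # shape [c*h*w, c*h*w]

x = torch.randn(1, c, h, w, dtype=torch.float64)
y = conv(x)
# The dense matrix reproduces conv2d on flattened inputs.
print(torch.allclose(A @ x.reshape(-1), y.reshape(-1)))

A_pinv = torch.linalg.pinv(A)
x_hat = (A_pinv @ y.reshape(-1)).reshape(1, c, h, w)

# Moore-Penrose property A A^+ A = A, and consistency of x_hat with y.
print(torch.allclose(A @ A_pinv @ A, A, atol=1e-6))
print((conv(x_hat) - y).abs().max())
```

The matrix has (c*h*w)^2 entries, so for c=32 and 128x128 images it would be far too large to store, which is why a structured approach is needed.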

  • This takes the problem to another level. Convolution is a linear operation, so you can determine the matrix of the operation and solve a least-squares problem directly [1], or compute the pseudo-inverse as you mentioned and then apply it to different outputs to predict a projection of the input.

    I am changing your code to use padding=0

    import torch
    import torch.nn.functional as F
    # your code
    c = 32
    K = torch.randn(c, c, 3, 3)  # 3x3 kernel, as in the question
    x = torch.randn(4, c, 128, 128)
    y = F.conv2d(x, K, bias=torch.zeros((c,)))

    Also, as you probably already suspected, the convolution can be computed as ifft(fft(h)*fft(x)). However, conv2d actually computes a cross-correlation, so you have to conjugate the filter's transform, which gives ifft(conj(fft(h))*fft(x)). You also have to apply this along two axes and make sure both FFTs are calculated at the same size. Since the data is real, we can use the multi-dimensional real FFT. To be complete, conv2d works on multiple channels, so we have to sum the convolutions over the input channels. Since the FFT is linear, we can simply compute these sums in the frequency domain using einsum.

    s = x.shape[-2:]  # FFT size: the full input size, so x is not cropped
    K_f = torch.fft.rfftn(K, s)
    x_f = torch.fft.rfftn(x, s)
    y_f = torch.einsum('jkxy,ikxy->ijxy', K_f.conj(), x_f)
    y_hat = torch.fft.irfftn(y_f, s)

    Except for the borders it should be accurate (remember FFT computes a cyclic convolution).

    # compare only the region that conv2d actually produced
    print(torch.max(torch.abs(y_hat[:, :, :y.shape[-2], :y.shape[-1]] - y)))

    Now, notice the pattern jk,ik->ij in the einsum. It means y_f[i,j] = sum_k(conj(K_f[j,k]) * x_f[i,k]), i.e. y_f = x_f @ conj(K_f).T, where @ is the matrix product on the first two dimensions, evaluated independently at every frequency. So to invert this operation we can interpret the first two dimensions as matrices. The function pinv computes pseudo-inverses on the last two axes, so in order to use it we have to permute the axes. If we right-multiply the output by the pseudo-inverse of the transposed K_f (conjugated, as in the forward pass), we invert this operation.

    s = 128,128
    K_f = torch.fft.rfftn(K, s)
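    # pinv works on the last two axes; .T on a 4-D tensor is deprecated, so permute explicitly
    K_f_inv = torch.linalg.pinv(K_f.permute(2, 3, 0, 1)).permute(2, 3, 0, 1)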
    y_f = torch.fft.rfftn(y_hat, s)
    x_f = torch.einsum('jkxy,ikxy->ijxy', K_f_inv.conj(), y_f)
    x_hat = torch.fft.irfftn(x_f, s)
    print(torch.mean((x - x_hat)**2) / torch.mean((x)**2))

    Notice that I am using the full (cyclic) convolution, whereas conv2d actually crops the images. Let's apply that cropping:

    k = K.shape[-1]  # kernel size
    y_hat[:,:,128-(k-1):,:] = 0
    y_hat[:,:,:,128-(k-1):] = 0

    Repeating the calculation, you will see that the recovered input is no longer accurate. So you have to be careful about what you do with your convolution, but in situations where you can get this to work it will in fact be efficient.

    s = 128,128
    K_f = torch.fft.rfftn(K, s)
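    # pinv works on the last two axes; .T on a 4-D tensor is deprecated, so permute explicitly
    K_f_inv = torch.linalg.pinv(K_f.permute(2, 3, 0, 1)).permute(2, 3, 0, 1)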
    y_f = torch.fft.rfftn(y_hat, s)
    x_f = torch.einsum('jkxy,ikxy->ijxy', K_f_inv.conj(), y_f)
    x_hat = torch.fft.irfftn(x_f, s)
    print(torch.mean((x - x_hat)**2) / torch.mean((x)**2))
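
One situation where the FFT route should be exact is when the forward operator itself uses circular (wrap-around) boundaries, since that is the convolution the FFT actually computes. The following is a minimal sketch under that assumption; `circular_conv2d` and `fft_pinv_conv2d` are hypothetical helper names, not PyTorch functions, and the small sizes and float64 dtype are chosen only for illustration.

```python
import torch
import torch.nn.functional as F

def circular_conv2d(x, K):
    """Cross-correlate x with K using circular (wrap-around) boundaries."""
    kh, kw = K.shape[-2:]
    # Pad only bottom/right so that out[n] = sum_m K[m] * x[(n + m) mod N],
    # which is the cyclic convention of the FFT expression.
    x_pad = F.pad(x, (0, kw - 1, 0, kh - 1), mode="circular")
    return F.conv2d(x_pad, K)

def fft_pinv_conv2d(y, K):
    """Apply the pseudo-inverse of circular_conv2d(., K) to y."""
    s = y.shape[-2:]
    # Conjugate because conv2d is a cross-correlation. Shape [out, in, H, W//2+1].
    K_f = torch.fft.rfftn(K, s=s, dim=(-2, -1)).conj()
    # pinv acts on the last two axes: move (out, in) to the back, invert, move back.
    K_f_pinv = torch.linalg.pinv(K_f.permute(2, 3, 0, 1)).permute(2, 3, 0, 1)
    y_f = torch.fft.rfftn(y, s=s, dim=(-2, -1))
    # One small matrix product per frequency: [in, out] times [batch, out].
    x_f = torch.einsum('jkxy,ikxy->ijxy', K_f_pinv, y_f)
    return torch.fft.irfftn(x_f, s=s, dim=(-2, -1))

torch.manual_seed(0)
K = torch.randn(4, 4, 3, 3, dtype=torch.float64)
x = torch.randn(2, 4, 16, 16, dtype=torch.float64)
y = circular_conv2d(x, K)
x_hat = fft_pinv_conv2d(y, K)
# Relative reconstruction error of the input.
print(torch.mean((x - x_hat) ** 2) / torch.mean(x ** 2))
# Consistency with the observation: A(A_pinv(y)) should reproduce y.
print((circular_conv2d(x_hat, K) - y).abs().max())
```

The cost is one c-by-c pseudo-inverse per frequency instead of one pseudo-inverse of a (c*h*w)-by-(c*h*w) matrix. With zero padding or cropping, as in conv2d above, the operator is no longer exactly cyclic and this is only an approximation.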