python, image-processing, pytorch, fft

Upsampling images in the frequency domain using PyTorch


I'm trying to upsample an RGB image in the frequency domain using PyTorch. I'm following this article, which covers grayscale images. Since PyTorch processes the channels independently, I assume the color space is irrelevant here.

The basic steps outlined by this article are:

  1. Perform FFT on the image.

  2. Pad the FFT with zeros.

  3. Perform inverse FFT.

I wrote the following code to do this:

import torch
import torch.nn.functional as F  # needed for F.pad below
import cv2
import numpy as np


img = src = cv2.imread('orig.png')
torch_img = torch.from_numpy(img).to(torch.float32).permute(2, 0, 1) / 255.
fft = torch.fft.fft2(torch_img, norm="forward")
fr = fft.real
fi = fft.imag
fr = F.pad(fr, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)
fi = F.pad(fi, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)

fft_hires = torch.complex(fr, fi)
inv = torch.fft.ifft2(fft_hires, norm="forward").real

print(inv.max(), inv.min())
img = (inv.permute(1, 2, 0).detach()).clamp(0, 1)
img = (255 * img).numpy().astype(np.uint8)
cv2.imwrite('hires.png', img)

The original image:


The upscaled image:


Another interesting thing to note is the maximum and minimum values of the image pixels after performing IFFT: they are 2.2729 and -1.8376 respectively. Ideally, they should be 1.0 and 0.0.

Can someone please explain what's wrong here?


Solution

  • The usual DFT convention stores the 0 Hz component in the first sample, which for a 2-D spectrum is the upper-left corner, with the negative frequencies in the second half of each axis. Zero-padding around the spectrum only makes sense when the 0 Hz component is in the center, so that the zeros land at the highest frequencies. Most FFT tools provide a shift function that circularly shifts the result to put the 0 Hz component in the center. In PyTorch, apply torch.fft.fftshift after the FFT and torch.fft.ifftshift right before the inverse FFT to move the 0 Hz component back to the upper-left corner.
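To make the bin ordering concrete, here is a minimal 1-D sketch. The length 8 is an arbitrary choice for illustration; torch.fft.fftfreq is used only to label each bin with its frequency.

```python
import torch

# Frequency of each bin in the order torch.fft.fft returns them (n = 8):
# 0 Hz first, then positive frequencies, then negative frequencies.
freqs = torch.fft.fftfreq(8)

# After fftshift the bins run from the most negative to the most positive
# frequency, with 0 Hz at index n // 2.
centered = torch.fft.fftshift(freqs)

# ifftshift undoes the shift and restores the original ordering.
restored = torch.fft.ifftshift(centered)

print(freqs)
print(centered)
```

Padding the unshifted layout on both sides moves the 0 Hz bin away from index 0, so the inverse FFT interprets the original low frequencies as high ones.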

    import torch
    import torch.nn.functional as F
    import cv2
    import numpy as np
    
    
    img = src = cv2.imread('orig.png')
    torch_img = torch.from_numpy(img).to(torch.float32).permute(2, 0, 1) / 255.
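    # note the fftshift; it is restricted to the two spatial axes so the
    # channel axis is left untouched
    fft = torch.fft.fftshift(torch.fft.fft2(torch_img, norm="forward"), dim=(-2, -1))
    
    # pad the real and imaginary parts separately; F.pad takes
    # (left, right, top, bottom) for the last two axes
    fr = fft.real
    fi = fft.imag
    fr = F.pad(fr, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)
    fi = F.pad(fi, (fft.shape[-1]//2, fft.shape[-1]//2, fft.shape[-2]//2, fft.shape[-2]//2), mode='constant', value=0)
    
    # note the ifftshift, over the same two axes
    fft_hires = torch.fft.ifftshift(torch.complex(fr, fi), dim=(-2, -1))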
    inv = torch.fft.ifft2(fft_hires, norm="forward").real
    
    print(inv.max(), inv.min())
    img = (inv.permute(1, 2, 0).detach()).clamp(0, 1)
    img = (255 * img).numpy().astype(np.uint8)
    cv2.imwrite('hires.png', img)
    

    which produces the following hires.png

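As a hedged sketch rather than the answer's exact code, the same steps can be wrapped in a function and checked on synthetic data. The name fft_upsample_2x and the random test tensor are made up for illustration. The sketch writes the centered spectrum into a pre-allocated zero tensor instead of calling F.pad, and it shifts only the two spatial axes. For even heights and widths, every second sample of the result should reproduce the input.

```python
import torch


def fft_upsample_2x(x: torch.Tensor) -> torch.Tensor:
    """Roughly double H and W of a real (..., H, W) tensor by zero-padding its spectrum."""
    h, w = x.shape[-2:]
    ph, pw = h // 2, w // 2  # zeros added on each side, as in the code above
    # Center the 0 Hz bin along the two spatial axes only.
    spec = torch.fft.fftshift(torch.fft.fft2(x, norm="forward"), dim=(-2, -1))
    # Embed the centered spectrum in a larger all-zero complex tensor.
    padded = torch.zeros(
        *x.shape[:-2], h + 2 * ph, w + 2 * pw, dtype=spec.dtype, device=spec.device
    )
    padded[..., ph:ph + h, pw:pw + w] = spec
    # Move 0 Hz back to index 0 before inverting. With norm="forward" the
    # inverse applies no scaling, so the overall scale of the values is unchanged.
    out = torch.fft.ifft2(torch.fft.ifftshift(padded, dim=(-2, -1)), norm="forward")
    return out.real


torch.manual_seed(0)
x = torch.rand(3, 8, 6)  # hypothetical stand-in for a small RGB image
y = fft_upsample_2x(x)
print(y.shape)
# Samples at even positions should match the original pixels.
print(torch.allclose(y[..., ::2, ::2], x, atol=1e-5))
```

The interpolated values between the original samples can still overshoot [0, 1] slightly near sharp edges because of ringing, so the clamp before converting to uint8 remains useful.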