Tags: python, numpy, opencv, computer-vision, fft

How to properly calculate a PSD (Power Spectral Density) plot for images in order to remove periodic noise?


I'm trying to remove periodic noise from an image using a PSD plot. I've had some success, but I'm not sure whether what I'm doing is 100% correct.
This is basically a follow-up to this video lecture, which discusses this very subject for 1D signals.

What I have done so far:

  1. Initially I tried flattening the whole image and treating it as a 1D signal. This does give me a plot, but honestly the plot doesn't look right and the final result is not very appealing.

This is the first try:

import cv2
import numpy as np
import matplotlib.pyplot as plt

# img link https://github.com/VladKarpushin/Periodic-noise-removing-filter/blob/master/www/images/period_input.jpg?raw=true
img = cv2.imread('./img/periodic_noisy_image2.jpg', 0)
img_flattened = img.flatten().astype(np.float64)
n = img_flattened.shape[0] # 447561
fft = np.fft.fft(img_flattened, n)

# the values range is just absurdly large, so 
# we have to use log at some point to get the
# values range to become sensible!
# np.abs(fft)**2 equals fft*np.conj(fft), but is real-valued (so log and < work as expected)
psd = np.abs(fft)**2/n
freq = 1/n * np.arange(n)
L = np.arange(1,np.floor(n/2),dtype='int')

# use log so we have a sensible range! (the tiny offset avoids log(0) = -inf)
psd_log = np.log(psd + 1e-12)
print(f'{psd_log.min()=} {psd_log.max()=}')
# cut off range to remove noise!
indexes = psd_log<15
# zero the removed values so the cleaned PSD can be plotted for comparison
psd_cleaned = psd * indexes
# get the denoised fft
fft_cleaned = fft * indexes

# in case the initial parts were affected, 
# lets restore it from fft so the final image looks well.
# For a real signal, the negative frequencies mirroring indices 1..span-1
# sit at the end of the (unshifted) fft, so restore those too to keep it symmetric
span = 10
fft_cleaned[:span] = fft[:span]
fft_cleaned[n - span + 1:] = fft[n - span + 1:]

# get back the image
denoised_img = np.fft.ifft(fft_cleaned).real.clip(0,255).astype(np.uint8).reshape(img.shape)

plt.subplot(2,2,1), plt.imshow(img,cmap='gray'), plt.title('original image')
plt.subplot(2,2,2), plt.imshow(denoised_img, cmap='gray'), plt.title('denoise image')
plt.subplot(2,2,3), plt.plot(freq[L],psd[L]), plt.title('PSD')
plt.subplot(2,2,4), plt.plot(freq[L],psd_cleaned[L]), plt.title('PSD clean')
plt.show()

This is the output. The image is denoised a bit, but overall it doesn't sit right with me: I'd expect to get at least as good a result as in my second attempt, and the plots also look weird!
[Image: original image, denoised image, PSD, and cleaned PSD]
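Flattening joins the end of each row to the start of the next, so in general a 2D periodic pattern no longer maps to a few clean peaks in the 1D spectrum. A 2D periodogram keeps both axes separate. Below is a minimal NumPy sketch (the helper name `psd_2d` and the synthetic stripe image are illustrative assumptions, not code from the post). It normalizes by the number of pixels, so by Parseval's theorem `psd.sum()` equals the image energy, and it returns `fftshift`-ed frequency axes in cycles/pixel for labeling the plot.

```python
import numpy as np


def psd_2d(img):
    """Centered 2D periodogram of a grayscale image.

    Returns (psd, fy, fx): psd = |F|^2 / (rows * cols) with the zero frequency
    in the middle, plus the matching row/column frequencies in cycles/pixel.
    """
    img = np.asarray(img, dtype=np.float64)
    rows, cols = img.shape
    F_shift = np.fft.fftshift(np.fft.fft2(img))
    psd = np.abs(F_shift) ** 2 / (rows * cols)
    fy = np.fft.fftshift(np.fft.fftfreq(rows))  # vertical frequency of each row of psd
    fx = np.fft.fftshift(np.fft.fftfreq(cols))  # horizontal frequency of each column
    return psd, fy, fx


# Synthetic example: vertical stripes, 0.125 cycles/pixel along x
rows, cols = 64, 128
x = np.arange(cols)
stripes = 128 + 40 * np.tile(np.sin(2 * np.pi * 0.125 * x), (rows, 1))
psd, fy, fx = psd_2d(stripes)

# Parseval: with this normalization the PSD sums to the image energy
print(np.isclose(psd.sum(), (stripes ** 2).sum()))  # True

# strongest peak apart from the DC term
psd_no_dc = psd.copy()
psd_no_dc[rows // 2, cols // 2] = 0
r, c = np.unravel_index(np.argmax(psd_no_dc), psd.shape)
print(fy[r], abs(fx[c]))  # 0.0 0.125
```

To plot it with labeled axes, something like `plt.imshow(np.log(psd + 1e-12), extent=(fx[0], fx[-1], fy[-1], fy[0]))` should work with Matplotlib's default `origin='upper'`.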

  2. In my second attempt, I simply calculated the 2D power spectrum the normal way and got a much better result, IMHO:
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Read the image in grayscale
img = cv2.imread('./img/periodic_noisy_image2.jpg', 0)

# Perform 2D Fourier transform
fft = np.fft.fftn(img)
fft_shift = np.fft.fftshift(fft)

# Calculate the power spectrum; np.abs(fft_shift)**2 is the same as (fft_shift*np.conj(fft_shift)).real
# note that np.abs(fft_shift) is the magnitude, i.e. the square root of the power spectrum, so we **2 it to get the actual power spectrum!
# a proper PSD would also be divided by the number of pixels, but in log scale that only
# shifts every value by the same constant, so for visualization it can be skipped
# I use log to make large numbers smaller and small numbers larger so they show up properly in visualization
psd = np.log(np.abs(fft_shift)**2)

# now I can filter out the bright spots which signal noise:
# take the indexes belonging to these large values
# and then use them to set those entries in the actual fft to 0 to suppress them
# with a threshold of 20-22 the image gets too smoothed out, and with >24 it's still visibly noisy
ind = psd<23
psd2 = psd*ind
fft_shift2 = ind * fft_shift
# since this is not accurate, we may very well end up destroying
# the center of the fft, which contains low-frequency image information
# (it has large values there as well), so we grab that area from the fft and copy
# it back to restore the lost values
cx,cy = img.shape[0]//2, img.shape[1]//2
area = 20
# restore the center in case it was overwritten!
# (+1 so the restored block is symmetric around the zero frequency)
fft_shift2[cx-area:cx+area+1,cy-area:cy+area+1] = fft_shift[cx-area:cx+area+1,cy-area:cy+area+1]

ifft_shift2 = np.fft.ifftshift(fft_shift2)
denoised_img = np.fft.ifftn(ifft_shift2).real.clip(0,255).astype(np.uint8)

# Get frequencies for each dimension (x runs along columns, y along rows)
freq_x = np.fft.fftfreq(img.shape[1])
freq_y = np.fft.fftfreq(img.shape[0])

# Create a meshgrid of frequencies
freq_x, freq_y = np.meshgrid(freq_x, freq_y)

# Plot the PSD
plt.figure(figsize=(10, 7))
plt.subplot(2,2,1), plt.imshow(img, cmap='gray'), plt.title('img')
plt.subplot(2,2,2), plt.imshow(denoised_img, cmap='gray'), plt.title('denoised image')
#plt.subplot(2,2,3), plt.imshow(((1-ind)*255)), plt.title('mask-inv')
plt.subplot(2,2,3), plt.imshow(psd2, extent=(np.min(freq_x), np.max(freq_x), np.min(freq_y), np.max(freq_y))), plt.title('Power Spectrum Density[cleaned]')
plt.subplot(2,2,4), plt.imshow(psd, extent=(np.min(freq_x), np.max(freq_x), np.min(freq_y), np.max(freq_y))),plt.title('Power Spectrum Density[default]')
plt.xlabel('Frequency (X)')
plt.ylabel('Frequency (Y)')
plt.colorbar()
plt.show()

[Images: output of the second attempt]

This seems to work, but the result still isn't good. I'm not sure whether I'm doing something wrong here or whether this is the best that can be achieved!

  3. Next, I tried drawing a rectangle around each bright spot and setting everything inside it to zero. This way I make sure the surrounding values are also taken care of as much as possible. This is the output I get:

import cv2
import numpy as np

img = cv2.imread('./img/periodic_noisy_image2.jpg')
while (True):
    # calculate the dft over the spatial axes only, so the color channels aren't mixed
    ffts = np.fft.fft2(img, axes=(0, 1))
    # now shift to center for better interpretation
    ffts_shifted = np.fft.fftshift(ffts, axes=(0, 1))
    # power spectrum (log scale); +1 avoids log(0), and clipping before the cast
    # keeps values > 255 from wrapping around
    ffts_shifted_mag = (20*np.log(np.abs(ffts_shifted) + 1)).clip(0, 255).astype(np.uint8)
    # use selectROI to select the spots we want to set to 0!
    noise_rois = cv2.selectROIs('select periodic noise spots(press Spc to take selection, press esc to end selection)', ffts_shifted_mag,False, False,False)
    print(f'{noise_rois=}')
    # now set the area in fft_shifted to zero
    # selectROIs returns (x, y, w, h): x/w run along columns, y/h along rows
    for x, y, w, h in noise_rois:
        # we need to provide a complex number!
        ffts_shifted[y:y+h, x:x+w] = 0+0j

    # shift back
    iffts_shifted = np.fft.ifftshift(ffts_shifted, axes=(0, 1))
    iffts = np.fft.ifft2(iffts_shifted, axes=(0, 1))

    # get back the image
    img_denoised = iffts.real.clip(0,255).astype(np.uint8)

    # lets calculate the new image magnitude
    denoise_ffts = np.fft.fft2(img_denoised, axes=(0, 1))
    denoise_ffts_shifted = np.fft.fftshift(denoise_ffts, axes=(0, 1))
    denoise_mag = (20*np.log(np.abs(denoise_ffts_shifted) + 1)).clip(0, 255).astype(np.uint8)

    cv2.imshow('img-with-periodic-noise', img)
    cv2.imshow('ffts_shifted_mag', ffts_shifted_mag)
    cv2.imshow('denoise_mag',denoise_mag)
    cv2.imshow('img_denoised', img_denoised)
    # waitKey(0) blocks until a key is pressed, otherwise we couldn't see the result!
    key = cv2.waitKey(0)&0xFF
    cv2.destroyAllWindows()

    if key == ord('q'):
        break

[Images: output windows of the third attempt]
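As point 3 of the answer below explains, zeroing a rectangle on only one side of the spectrum makes the inverse transform complex, and taking `.real` then discards part of the signal. Here is a minimal sketch of zeroing a rectangle together with its point-mirror in an `fftshift`-ed spectrum. `zero_rect_symmetric` and `max_imag` are hypothetical helpers, the random test image is synthetic, and the rectangle is assumed to use OpenCV's `(x, y, w, h)` ROI order.

```python
import numpy as np


def zero_rect_symmetric(F_shift, x, y, w, h):
    """Zero an (x, y, w, h) rectangle (OpenCV ROI order) in a centered,
    fftshift-ed spectrum, plus its point-mirror, so the inverse stays real."""
    rows, cols = F_shift.shape[:2]
    rr = np.arange(y, y + h)
    cc = np.arange(x, x + w)
    # after fftshift the zero frequency sits at index N // 2, so the
    # mirror of index i (frequency -k) is (2 * (N // 2) - i) % N
    rr_m = (2 * (rows // 2) - rr) % rows
    cc_m = (2 * (cols // 2) - cc) % cols
    F_shift[np.ix_(rr, cc)] = 0
    F_shift[np.ix_(rr_m, cc_m)] = 0
    return F_shift


def max_imag(F_shift):
    """Largest imaginary part of the inverse transform of a shifted spectrum."""
    return np.abs(np.fft.ifft2(np.fft.ifftshift(F_shift)).imag).max()


rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(100, 121))  # odd width on purpose
F_shift = np.fft.fftshift(np.fft.fft2(img))

one_sided = F_shift.copy()
one_sided[10:20, 30:40] = 0  # only the selected rectangle
both = zero_rect_symmetric(F_shift.copy(), x=30, y=10, w=10, h=10)

print(max_imag(one_sided), max_imag(both))  # large vs. rounding-level
```

In a selection loop, each ROI `(x, y, w, h)` returned by `cv2.selectROIs` could be passed to `zero_rect_symmetric(ffts_shifted, x, y, w, h)`; because `np.ix_` indexes only the first two axes, it also works on an `(H, W, 3)` spectrum computed with `axes=(0, 1)`.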

Again, I assumed that by removing this periodic noise the image would look much better, but I can still see patterns, which means the noise was not removed completely, even though I did remove the bright spots!

Denoising this image with the same method is even harder (so far impossible):
[Image: second noisy image]

This is clearly periodic noise, so what am I missing or doing wrong here?

For reference, this is the other image with periodic noise that I have been experimenting with:
[Image: other noisy image]

Update:

After reading the comments and suggestions so far, I came up with the following version, which works decently overall, but I still face these issues:

  1. I don't get tiny imaginary values, even when the output looks fairly good, so I can't rely on this check to see what has gone wrong: the non-zero imaginary parts show up both when there is barely any noticeable noise and when there is noise everywhere!
  2. I still see a considerable amount of noise in some images (example below). It would be great to know whether this is expected and I should move on, or whether something is wrong and needs to be addressed.

import cv2
import numpy as np

# image with periodic noise; color and grayscale inputs both work
img = cv2.imread('./img/periodic_noisy_image2.jpg')

def onchange(x):pass
cv2.namedWindow('options')
cv2.createTrackbar('threshold', 'options', 130, 255, onchange)
cv2.createTrackbar('area', 'options', 40, max(*img.shape[:2]), onchange)
cv2.createTrackbar('pad', 'options', 0, max(*img.shape[:2]), onchange)
cv2.createTrackbar('normalize_output', 'options', 0, 1, onchange)

while(True):

    threshold = cv2.getTrackbarPos('threshold', 'options')
    area = cv2.getTrackbarPos('area', 'options')
    pad = cv2.getTrackbarPos('pad', 'options')
    normalize_output = cv2.getTrackbarPos('normalize_output', 'options')
    
    input_img = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_REFLECT) if pad>0 else img
    
    # 2D FFT over the spatial axes only, so color channels are transformed independently
    fft = np.fft.fft2(input_img, axes=(0, 1))
    fft_shift = np.fft.fftshift(fft, axes=(0, 1))
    # note since we plan on normalizing the magnitude spectrum,
    # we don't clip and we don't cast here!
    # +1 so for the images that have 0s we don't get -inf down the road and don't face issues when we want to normalize and create a mask out of it!
    fft_shift_mag = 20*np.log(np.abs(fft_shift)+1)
    # now let's normalize and get a mask out of it;
    # the idea is to identify bright spots and set them to 0
    # while retaining the center of the fft, as it has a lot
    # of image information
    fft_shift_mag_norm = cv2.normalize(fft_shift_mag, None, 0,255, cv2.NORM_MINMAX)
    # now let's threshold and get our mask
    if img.ndim>2:
        mask = np.array([cv2.threshold(fft_shift_mag_norm[...,i], threshold, 255, cv2.THRESH_BINARY)[1] for i in range(3)])
        # the mask/img needs to be contiguous (a simple .copy() would work as well!)
        mask = np.ascontiguousarray(mask.transpose((1,2,0)))
    else:
        ret, mask = cv2.threshold(fft_shift_mag_norm, threshold, 255, cv2.THRESH_BINARY)
        
    rows, cols = input_img.shape[:2]
    cy, cx = rows//2, cols//2
    # keep the center (low frequencies); cv2.circle takes the center as (x, y) = (column, row)
    mask = cv2.circle(mask, (cx, cy), radius=area, color=0, thickness=cv2.FILLED)

    # now that we have our mask prepared, we can simply use it with the actual fft to 
    # set all these bright places to 0
    fft_shift[mask!=0] = 0+0j
    ifft_shift = np.fft.ifftshift(fft_shift, axes=(0, 1))
    ifft = np.fft.ifft2(ifft_shift, axes=(0, 1))
    img_denoised = ifft.real.clip(0,255).astype(np.uint8)
    img_denoised = img_denoised[pad:rows-pad, pad:cols-pad]
    
    # check that the imaginary part of the inverse transform (the spatial-domain result,
    # not the frequency-domain array ifft_shift) is close to zero, otherwise sth is wrong!
    almost_zero = np.allclose(ifft.imag, 0, atol=1e-8)
    if not almost_zero:
        print('imaginary components not close to 0, something is wrong!')
    else:
        print('all is good!')
        
    # do a final contrast stretching: 
    if normalize_output:
        p2, p98 = np.percentile(img_denoised, (2, 98))
        img_denoised = img_denoised.clip(p2, p98)
        img_denoised = cv2.normalize(img_denoised, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        
    cv2.imshow('input_img', input_img)
    # float images in 0..255 would show up almost white in imshow, so cast for display
    cv2.imshow('fft-shift-mag-norm', fft_shift_mag_norm.astype(np.uint8))
    cv2.imshow('fft_shift_mag', ((fft_shift_mag/fft_shift_mag.max())*255).clip(0,255).astype(np.uint8))
    cv2.imshow('mask', mask.astype(np.uint8))
    cv2.imshow('denoised', img_denoised)

    key = cv2.waitKey(30)&0xFF
    if key == ord('q') or key == 27:
        cv2.destroyAllWindows()
        break

Relatively good output:
[Images: output of the updated code]

Not so good: this is the one example where I still get lots of noise. I'm not sure whether this is the best I can expect or whether there is still room for improvement:
[Images: output for the noisier example]

There are other samples where I couldn't remove all the noise either, such as this one (I could tweak it a bit, but there would still be artifacts):
[Images: output for this sample]

I attributed this to the low quality of the image itself and accepted it. However, I expected the second example to have room for improvement; I thought I should ultimately be able to get something like this, or close to it:
[Image: the expected result]

  • Are my assumptions incorrect?
  • Are the artifacts we are seeing in the outputs periodic noise or some other type of noise?
  • Relatively speaking, is this the best one can hope for with this method, i.e. purely by removing periodic noise and without resorting to anything more advanced?

Solution

  • Here are some things you can do to improve your results:

    1. The hard transition from 1 to 0 in your frequency-domain kernel (ind in the 2nd block of code, it is implicit in the 3rd) means that you’ll get lots of ringing artifacts back in the spatial domain. This is 99% of the strange stuff in your output.

      To see this ringing, try contrast-stretching the output instead of clipping (clipping is correct, but the alternative shows you all the artifacts you're clipping away). plt.imshow will show you the contrast-stretched image if you leave it as a floating-point array (i.e., just do plt.imshow(np.fft.ifftn(ifft_shift2).real)).

      You could also inverse-transform the kernel ind. You’ll see it has a very large extent and does a lot of ringing.

      The better approach to create the frequency-domain kernel is to draw Gaussian-shaped blobs, or in some other way taper the edges of the squares you draw in the 3rd code block. One easy way to draw rectangles with tapered edges is to use the function dip.DrawBandlimitedBox in DIPlib (disclaimer: I’m an author). I’m not sure if there are other image processing libraries with an equivalent function.

    2. Handle edge effects. These are not very visible yet, but once you take care of #1, they’ll become more apparent. This is not easy in this application, because the noise pattern has to be continued at the image edge in a different way from the signal. See this Q&A for an example.

    3. Also, note that the frequency-domain kernel you construct must be perfectly symmetric around the origin. For every box you draw on the left half of the image, you need to draw a box on the right side at exactly the mirrored location (mirror the coordinates both horizontally and vertically). Verify that the imaginary component of the inverse transform is approximately 0; if the boxes are not perfectly symmetric, it won't be. When the kernel is not perfectly symmetric, you'll discard some of the signal when you take the real part of the inverse transform, and this discarded signal has a pattern of its own…

    4. There are more strong dots at higher frequencies than the ones you are removing. Removing these will further improve the results. Alternatively, use a low-pass filter that removes all of the frequencies at the dots and higher (draw a disk with tapered edges around the origin in the frequency domain). This would match what we see when we look at the image from a bit of a distance.
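Points 1 and 3 can also be illustrated in plain NumPy. The sketch below is an assumption-laden illustration, not the DIPlib approach that follows: the helper names, the synthetic blob-plus-cosine image, and the choices `sigma=3` and `half=6` are made up for the example, and peaks are assumed to lie away from the spectrum border. It builds a notch mask from Gaussian-shaped dips at each peak and its point-mirror, checks that the inverse transform is real, and compares how far the spatial kernel of a hard-edged box mask spreads versus the Gaussian one.

```python
import numpy as np


def gaussian_notch_mask(shape, peaks, sigma):
    """Mask for a centered (fftshift-ed) spectrum: 1 everywhere, with a smooth
    Gaussian-shaped dip to 0 at each (row, col) peak and at its point-mirror."""
    rows, cols = shape
    cy, cx = rows // 2, cols // 2  # zero frequency after fftshift
    yy, xx = np.mgrid[0:rows, 0:cols]
    mask = np.ones(shape)
    for py, px in peaks:
        # the mirrored dip keeps the mask point-symmetric, so the output stays real
        for qy, qx in ((py, px), (2 * cy - py, 2 * cx - px)):
            d2 = (yy - qy) ** 2 + (xx - qx) ** 2
            mask *= 1 - np.exp(-d2 / (2 * sigma ** 2))
    return mask


def hard_box_mask(shape, peaks, half):
    """Same peaks, but zeroed with hard-edged (2*half+1)-wide squares."""
    rows, cols = shape
    cy, cx = rows // 2, cols // 2
    mask = np.ones(shape)
    for py, px in peaks:
        for qy, qx in ((py, px), (2 * cy - py, 2 * cx - px)):
            mask[qy - half:qy + half + 1, qx - half:qx + half + 1] = 0
    return mask


def tail_energy(mask, radius):
    """Energy of the mask's spatial-domain kernel farther than `radius` pixels
    from the kernel's center, i.e. how far its ringing spreads."""
    kernel = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(mask)).real)
    rows, cols = mask.shape
    yy, xx = np.mgrid[0:rows, 0:cols]
    far = (yy - rows // 2) ** 2 + (xx - cols // 2) ** 2 > radius ** 2
    return np.sum(kernel[far] ** 2)


# Synthetic test image: a smooth blob plus a single periodic noise pattern
rows, cols = 256, 256
yy, xx = np.mgrid[0:rows, 0:cols]
signal = 100 + 50 * np.exp(-((yy - 128) ** 2 + (xx - 128) ** 2) / (2 * 40 ** 2))
noise = 20 * np.cos(2 * np.pi * (20 * yy / rows + 30 * xx / cols))
img = signal + noise

# The noise shows up 20 rows and 30 columns away from the center of the shifted spectrum
peaks = [(rows // 2 + 20, cols // 2 + 30)]

F_shift = np.fft.fftshift(np.fft.fft2(img))
mask = gaussian_notch_mask(img.shape, peaks, sigma=3)
out = np.fft.ifft2(np.fft.ifftshift(F_shift * mask))

print(np.abs(out.imag).max())           # tiny, since the mask is symmetric
print(np.abs(out.real - signal).max())  # how much noise is left

hard = hard_box_mask(img.shape, peaks, half=6)
print(tail_energy(hard, 40), tail_energy(mask, 40))  # the hard box rings much farther
```

The design choice is the one from point 1: a multiplicative mask whose dips fall off smoothly has a compact spatial kernel, while the 1-to-0 jump of a box spreads sinc-like ringing far across the image.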


    Here's how I would implement this using DIPlib:

    import diplib as dip
    
    img = dip.ImageRead("7LOLyaeK.jpg")  # the soldier image
    
    # Fourier transform
    F = dip.FourierTransform(img)
    F.ResetPixelSize()  # so we can see coordinates in pixels
    dip.viewer.ShowModal(F)  # click on "MAG" (in 3rd column) and "LOG" (in 2nd column)
    # I see peaks at the following locations (one of each pair of peaks):
    pos = [
        (513, 103),
        (655, 170),
        (799, 236),
        (654, 303),
    ]
    # Let's ignore all the other peaks for now, though we should take care of them too
    
    # Mask out peaks
    mask = F.Similar("SFLOAT")
    mask.Fill(1)
    origin = (mask.Size(0) // 2, mask.Size(1) // 2)
    sigma = 5
    value = 2 * 3.14159 * sigma**2  # we need to undo the normalization in dip.DrawBandlimitedPoint()
    for p in pos:
        dip.DrawBandlimitedPoint(mask, origin=p, value=-value, sigmas=sigma)
        p = (2 * origin[0] - p[0], 2 * origin[1] - p[1])
        dip.DrawBandlimitedPoint(mask, origin=p, value=-value, sigmas=sigma)
    
    dip.viewer.ShowModal(mask)
    
    # Apply the filter and inverse transform
    out = dip.InverseFourierTransform(F * mask, {"real"})
    dip.viewer.Show(img)
    dip.viewer.Show(out)
    dip.viewer.Spin()
    

    This doesn't look very good because we didn't take care of all the other peaks. The dithering pattern is not just four sine waves, it's quite a bit more complex than that. But we don't actually expect there to be any frequencies in the image above that of the dithering pattern. So, you're actually better off simply applying a low-pass filter in this case:

    out2 = dip.Gauss(img, 2.1)  # Finding the best cutoff is a bit of a trial-and-error
    dip.viewer.Show(img)
    dip.viewer.Show(out)
    dip.viewer.Show(out2)
    dip.viewer.Spin()
    

    The Veritasium image is a big old mess: looking at its Fourier transform, there's just not a whole lot left that we can recover. Again, applying a low-pass filter gives you a lower bound on what you could potentially accomplish with a linear filter.
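For the low-pass route from point 4, here is a NumPy sketch of a disk with tapered edges around the origin of the frequency domain. It uses a raised-cosine transition, which is a swapped-in taper shape (`dip.Gauss` above uses a Gaussian instead). The helper names are hypothetical, `cutoff` and `taper` are in cycles/pixel, and the values below are illustrative only; as noted above, finding the best cutoff is trial and error.

```python
import numpy as np


def tapered_disk_lowpass(shape, cutoff, taper):
    """Centered (fftshift-ed) low-pass mask: 1 for radial frequencies up to
    `cutoff`, 0 beyond `cutoff + taper` (cycles/pixel), raised cosine between."""
    rows, cols = shape
    fy = np.fft.fftshift(np.fft.fftfreq(rows))[:, None]
    fx = np.fft.fftshift(np.fft.fftfreq(cols))[None, :]
    r = np.hypot(fy, fx)
    t = np.clip((r - cutoff) / taper, 0.0, 1.0)  # 0 inside, 1 outside, ramp in between
    return 0.5 * (1 + np.cos(np.pi * t))


def lowpass(img, cutoff=0.1, taper=0.05):
    """Apply the tapered-disk low-pass to a 2D grayscale image."""
    F_shift = np.fft.fftshift(np.fft.fft2(img))
    mask = tapered_disk_lowpass(img.shape, cutoff, taper)
    return np.fft.ifft2(np.fft.ifftshift(F_shift * mask)).real


# Synthetic check: a low-frequency pattern plus a pixel-level checkerboard "dither"
rows, cols = 64, 64
yy, xx = np.mgrid[0:rows, 0:cols]
smooth = 128 + 50 * np.cos(2 * np.pi * 2 * xx / cols)
dither = 30 * (-1.0) ** (xx + yy)
out = lowpass(smooth + dither)
print(np.abs(out - smooth).max())  # only rounding error: the checkerboard is removed
```

Because the mask depends only on the radial frequency, it is automatically point-symmetric, so the inverse transform stays real without any extra mirroring.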