Tags: python, image-processing, numpy, camera, scipy

Automatically remove hot/dead pixels from an image in python


I am using numpy and scipy to process a number of images taken with a CCD camera. These images have a number of hot (and dead) pixels with very large (or small) values. These interfere with other image processing, so they need to be removed. Unfortunately, though a few of the pixels are stuck at either 0 or 255 and are always at the same value in all of the images, there are some pixels that are temporarily stuck at other values for a period of a few minutes (the data spans many hours).

I am wondering if there is a method for identifying (and removing) the hot pixels already implemented in Python. If not, what would be an efficient method for doing so? The hot/dead pixels are relatively easy to identify by comparing them with neighboring pixels. I could see writing a loop that looks at each pixel and compares its value to those of its 8 nearest neighbors. Alternatively, it seems nicer to use some kind of convolution to produce a smoother image and then subtract this from the image containing the hot pixels, making them easier to identify.

I have tried this "blurring method" in the code below, and it works okay, but I doubt that it is the fastest. Also, it gets confused at the edge of the image (probably since the gaussian_filter function is taking a convolution and the convolution gets weird near the edge). So, is there a better way to go about this?

Example code:

import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage

plt.figure(figsize=(8,4))
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)

#make a sample image
x = np.linspace(-5,5,200)
X,Y = np.meshgrid(x,x)
Z = 255*np.cos(np.sqrt(X**2 + Y**2))**2


for i in range(0,10):
    #Add some hot pixels
    Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=200,high=255)
    #and dead pixels
    Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=0,high=10)

#Then plot it
ax1.set_title('Raw data with hot pixels')
ax1.imshow(Z,interpolation='nearest',origin='lower')

#Now we try to find the hot pixels
blurred_Z = scipy.ndimage.gaussian_filter(Z, sigma=2)
difference = Z - blurred_Z

ax2.set_title('Difference with hot pixels identified')
ax2.imshow(difference,interpolation='nearest',origin='lower')

threshold = 15
hot_pixels = np.nonzero((difference>threshold) | (difference<-threshold))

#Don't include the hot pixels that we found near the edge:
count = 0
for y,x in zip(hot_pixels[0],hot_pixels[1]):
    if (x != 0) and (x != 199) and (y != 0) and (y != 199):
        ax2.plot(x,y,'ro')
        count += 1

print('Detected %i hot/dead pixels out of 20.' % count)
ax2.set_xlim(0,200); ax2.set_ylim(0,200)


plt.show()

And the output: [figure: 'Raw data with hot pixels' (left); 'Difference with hot pixels identified' (right)]
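
For reference, the comparison with the 8 nearest neighbors described in the question does not need an explicit loop. A minimal sketch, assuming a 2D array: `scipy.ndimage.median_filter` accepts a `footprint`, so a 3x3 footprint with the center switched off yields the median of the 8 neighbors for every pixel at once. The name `neighbor_outliers` and the default `threshold` are hypothetical, chosen only for illustration, and the threshold would need tuning for real camera data.

```python
import numpy as np
from scipy.ndimage import median_filter

def neighbor_outliers(image, threshold=50.0):
    """Flag pixels that differ strongly from the median of their 8 neighbors.

    Hypothetical helper for illustration. `threshold` is in the same
    units as the pixel values.
    """
    image = np.asarray(image, dtype=float)

    # 3x3 neighborhood with the center pixel excluded
    ring = np.ones((3, 3), dtype=bool)
    ring[1, 1] = False

    # mode='mirror' reflects about the center of the border pixel, so a
    # border pixel is not counted among its own neighbors
    neighbor_median = median_filter(image, footprint=ring, mode='mirror')

    return np.abs(image - neighbor_median) > threshold

# Flat test image with a hot corner pixel, a hot pixel and a dead pixel
demo = np.full((20, 20), 100.0)
demo[0, 0] = 255.0   # hot, in a corner
demo[5, 7] = 255.0   # hot
demo[12, 3] = 0.0    # dead

mask = neighbor_outliers(demo)
rows, cols = np.nonzero(mask)
print(list(zip(rows.tolist(), cols.tolist())))
```

A fixed threshold keeps the sketch simple. On an image with steep gradients near the border, the mirrored neighborhood can bias the median, so a different `mode` may behave better there.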


Solution

  • Basically, I think that the fastest way to deal with hot pixels is just to use a size=2 median filter. Then, poof, your hot pixels are gone and you also kill all sorts of other high-frequency sensor noise from your camera.

    If you really want to remove ONLY the hot pixels, you can subtract the median-filtered image from the original, much as I did with the Gaussian blur in the question, and replace only the outlying pixels with the values from the median-filtered image. This doesn't work well at the edges, so if you can ignore the pixels along the edge, things get a lot easier.

    If you want to deal with the edges, you can use the code below. However, it is not the fastest:

    import numpy as np
    import matplotlib.pyplot as plt
    import scipy.ndimage
    
    plt.figure(figsize=(10,5))
    ax1 = plt.subplot(121)
    ax2 = plt.subplot(122)
    
    #make some sample data
    x = np.linspace(-5,5,200)
    X,Y = np.meshgrid(x,x)
    Z = 100*np.cos(np.sqrt(X**2 + Y**2))**2 + 50
    
    np.random.seed(1)
    for i in range(0,11):
        #Add some hot pixels
        Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=200,high=255)
        #and dead pixels
        Z[np.random.randint(low=0,high=199),np.random.randint(low=0,high=199)]= np.random.randint(low=0,high=10)
    
    #And some hot pixels in the corners and edges
    Z[0,0]   =255
    Z[-1,-1] =255
    Z[-1,0]  =255
    Z[0,-1]  =255
    Z[0,100] =255
    Z[-1,100]=255
    Z[100,0] =255
    Z[100,-1]=255
    
    #Then plot it
    ax1.set_title('Raw data with hot pixels')
    ax1.imshow(Z,interpolation='nearest',origin='lower')
    
    def find_outlier_pixels(data,tolerance=3,worry_about_edges=True):
        #This function finds the hot or dead pixels in a 2D dataset.
        #tolerance is the number of standard deviations used to cut off the hot pixels.
        #If you want to ignore the edges and greatly speed up the code, then set
        #worry_about_edges to False.
        #
        #The function returns the coordinates of the hot pixels (a 2xN array of
        #row and column indices) and also an image with the hot pixels removed
    
        from scipy.ndimage import median_filter
        blurred = median_filter(data, size=2)
        difference = data - blurred
        threshold = tolerance*np.std(difference)
    
        #find the hot pixels, but ignore the edges
        hot_pixels = np.nonzero((np.abs(difference[1:-1,1:-1])>threshold) )
        hot_pixels = np.array(hot_pixels) + 1 #because we ignored the first row and first column
    
        fixed_image = np.copy(data) #This is the image with the hot pixels removed
        for y,x in zip(hot_pixels[0],hot_pixels[1]):
            fixed_image[y,x]=blurred[y,x]
            
        if worry_about_edges:
            height,width = np.shape(data)
        
            ###Now get the pixels on the edges (but not the corners)###
    
            #left and right sides
            for index in range(1,height-1):
                #left side:
                med  = np.median(data[index-1:index+2,0:2])
                diff = np.abs(data[index,0] - med)
                if diff>threshold: 
                    hot_pixels = np.hstack(( hot_pixels, [[index],[0]]  ))
                    fixed_image[index,0] = med
                
                #right side:
                med  = np.median(data[index-1:index+2,-2:])
                diff = np.abs(data[index,-1] - med)
                if diff>threshold: 
                    hot_pixels = np.hstack(( hot_pixels, [[index],[width-1]]  ))
                    fixed_image[index,-1] = med
    
            #Then the top and bottom
            for index in range(1,width-1):
                #bottom:
                med  = np.median(data[0:2,index-1:index+2])
                diff = np.abs(data[0,index] - med)
                if diff>threshold: 
                    hot_pixels = np.hstack(( hot_pixels, [[0],[index]]  ))
                    fixed_image[0,index] = med
                
                #top:
                med  = np.median(data[-2:,index-1:index+2])
                diff = np.abs(data[-1,index] - med)
                if diff>threshold: 
                    hot_pixels = np.hstack(( hot_pixels, [[height-1],[index]]  ))
                    fixed_image[-1,index] = med
                      
            ###Then the corners###
    
            #bottom left
            med  = np.median(data[0:2,0:2])
            diff = np.abs(data[0,0] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[0],[0]]  ))
                fixed_image[0,0] = med
            
            #bottom right
            med  = np.median(data[0:2,-2:])
            diff = np.abs(data[0,-1] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[0],[width-1]]  ))
                fixed_image[0,-1] = med
    
            #top left
            med  = np.median(data[-2:,0:2])
            diff = np.abs(data[-1,0] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[height-1],[0]]  ))
                fixed_image[-1,0] = med
    
            #top right
            med  = np.median(data[-2:,-2:])
            diff = np.abs(data[-1,-1] - med)
            if diff>threshold: 
                hot_pixels = np.hstack(( hot_pixels, [[height-1],[width-1]]  ))
                fixed_image[-1,-1] = med
        
        return hot_pixels,fixed_image
    
    
    hot_pixels,fixed_image = find_outlier_pixels(Z)
    
    for y,x in zip(hot_pixels[0],hot_pixels[1]):
        ax1.plot(x,y,'ro',mfc='none',mec='r',ms=10)
    
    ax1.set_xlim(0,200)
    ax1.set_ylim(0,200)
    
    ax2.set_title('Image with hot pixels removed')
    ax2.imshow(fixed_image,interpolation='nearest',origin='lower',clim=(0,255))
    
    plt.show()
    

    Output: [figure: 'Raw data with hot pixels' with the detected pixels circled (left); 'Image with hot pixels removed' (right)]
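
The edge and corner loops in the function above can probably be avoided. `median_filter` has a `mode` argument that controls how the image is extended past its border (the default is `'reflect'`), so the filtered image is already defined at the edges and corners. Below is a rough sketch of that idea, assuming a 3x3 window instead of the `size=2` window used above. `replace_outliers` is a hypothetical name, and I have not benchmarked it against the function above.

```python
import numpy as np
from scipy.ndimage import median_filter

def replace_outliers(data, tolerance=3, size=3):
    """Replace outlier pixels with the local median, edges included.

    Hypothetical variant for illustration. `tolerance` is the number of
    standard deviations of (data - median-filtered data) used as cutoff.
    """
    data = np.asarray(data, dtype=float)

    # mode='reflect' extends the image past its border by reflection,
    # so the median is also defined for edge and corner pixels
    blurred = median_filter(data, size=size, mode='reflect')

    difference = data - blurred
    threshold = tolerance * np.std(difference)
    mask = np.abs(difference) > threshold

    # Keep the original value everywhere except at flagged pixels
    fixed_image = np.where(mask, blurred, data)
    hot_pixels = np.array(np.nonzero(mask))  # 2xN array: rows, columns
    return hot_pixels, fixed_image

# Same kind of sample data as above, with a corner, an edge and an
# interior outlier planted by hand
x = np.linspace(-5, 5, 200)
X, Y = np.meshgrid(x, x)
Z = 100 * np.cos(np.sqrt(X**2 + Y**2))**2 + 50
Z[0, 0] = 255     # hot pixel in a corner
Z[0, 100] = 255   # hot pixel on an edge
Z[50, 60] = 0     # dead pixel in the interior

hot_pixels, fixed_image = replace_outliers(Z)
print('Flagged %i pixels' % hot_pixels.shape[1])
```

Because the threshold is derived from the standard deviation of the whole difference image, a few ordinary pixels in steep regions may be flagged as well. They are then replaced by their local median, which changes them only slightly.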