
2D median filter, ignore nan values


As part of my project, I need code that performs median filtering over an r x r window and ignores nan values.

I currently use MATLAB's nlfilter function. The problem is that it is extremely slow: it takes almost 5 seconds for a 300x300 example, while MATLAB's medfilt2 takes 0.2 seconds. Does anyone have a more efficient and elegant solution?

Note: the behavior on the boundaries of the image isn't important in my case. In this example, nlfilter automatically pads the array with zeros, but other solutions such as boundary duplication are OK as well.

code example:

%Initialize input
r = 3; % window size is 3x3
I = [9,1,6,10,1,5,4;2,4,3,8,8,NaN,5;4,5,8,6,2,NaN,3;5,NaN,6,4,NaN,4,9;3,1,10,9,4,3,2;10,9,10,10,6,NaN,5;10,9,4,1,2,7,2];

%perform median filter on rxr window, ignore nans
nanMask = isnan(I); % remember where the nans are
f = @(A)median(A(~isnan(A)));
filteredRes = nlfilter(I, [r r], f);
filteredRes(nanMask) = nan;

Expected Results

before filtering:

I =
 9     1     6    10     1     5     4
 2     4     3     8     8   NaN     5
 4     5     8     6     2   NaN     3
 5   NaN     6     4   NaN     4     9
 3     1    10     9     4     3     2
10     9    10    10     6   NaN     5
10     9     4     1     2     7     2

after filtering:

filteredRes =
     0    2.0000    3.0000    3.0000    3.0000    2.5000         0
2.0000    4.0000    6.0000    6.0000    6.0000       NaN    3.0000
3.0000    4.5000    5.5000    6.0000    5.0000       NaN    3.0000
2.0000       NaN    6.0000    6.0000       NaN    3.0000    2.5000
2.0000    7.5000    9.0000    7.5000    4.0000    4.0000    2.5000
3.0000    9.0000    9.0000    6.0000    5.0000       NaN    2.0000
     0    9.0000    4.0000    2.0000    1.5000    2.0000         0

Thanks!


Solution

  • You can first pad the image using padarray where you want to pad floor(r/2) pixels on each side, then use im2col to restructure the padded image so that each neighbourhood of pixels is placed in separate columns. Next, you'll need to set all of the nan values to a dummy value first so you don't interfere with the median calculation... something like zero perhaps. After, find the median of each column, then reshape back into an image of the proper size.

    Something like this should work:

    r = 3;
    nanMask = isnan(I); % Define nan mask
    Ic = I;
    Ic(nanMask) = 0; % Create new copy of image and set nan elements to zero
    IP = padarray(Ic, floor([r/2 r/2]), 'both'); % Pad image
    IPc = im2col(IP, [r r], 'sliding'); % Transform into columns
    out = reshape(median(IPc, 1), size(I,1), size(I,2)); % Find median of each column and reshape back
    out(nanMask) = nan; % Set nan elements back
    

    We get:

    >> out
    
    out =
    
         0     2     3     3     1     1     0
         2     4     6     6     5   NaN     0
         2     4     5     6     4   NaN     0
         1   NaN     6     6   NaN     3     2
         1     6     9     6     4     4     2
         3     9     9     6     4   NaN     2
         0     9     4     2     1     2     0
    

    With the above approach, what's slightly different with your expected results is that we have set all nan values to 0 and these values are included in the median. In addition, should the number of elements be even in the median, then I simply chose the element to the right of the ambiguity as the final output.
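
    To make the effect of the zero substitution concrete, here is a small sketch on a single 3 x 3 neighbourhood, the one centred on row 2, column 5 of the example input. The variable names (w, wz, medZero, medSkip) are made up for illustration:

    ```matlab
    % 3 x 3 neighbourhood around I(2,5) from the example input
    w = [10 1 5; 8 8 NaN; 6 2 NaN];

    % Approach above: replace nan with 0, then take the median of all 9 values
    wz = w;
    wz(isnan(wz)) = 0;
    medZero = median(wz(:));          % the two zeros pull the median down

    % What the question's nlfilter call computes: median of the non-nan values only
    medSkip = median(w(~isnan(w)));

    fprintf('zeros included: %g, nans skipped: %g\n', medZero, medSkip);
    ```

    The first value is 5, the entry at that position in the output above, and the second is 6, the entry in your expected results.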

    This may not be what you want specifically. A more valid approach would be to sort all of the columns individually while leaving the nan values intact, then determine the last element for each column that is valid and for each of these elements, determine where the halfway point is and choose those from the sorted columns. A benefit of using sort is that the nan values are pushed towards the end of the array.

    Something like this could work:

    r = 3;
    nanMask = isnan(I); % Define nan mask
    IP = padarray(I, floor([r/2 r/2]), 'both'); % Pad image
    IPc = im2col(IP, [r r], 'sliding'); % Transform into columns
    IPc = sort(IPc, 1, 'ascend'); % Sort the columns
    [~,ind] = max(isnan(IPc), [], 1); % For each column, find the first nan, which sits one past the last valid number
    ind(ind == 1) = r*r; % Handles the case when there are all valid numbers per column
    ind = ceil(ind / 2); % Find the halfway point
    out = reshape(IPc(sub2ind(size(IPc), ind, 1:size(IPc,2))), size(I,1), size(I,2)); % Find median of each column and reshape back
    out(nanMask) = nan; % Set nan elements back
    

    We now get:

    >> out
    
    out =
    
         0     2     3     3     5     4     0
         2     4     6     6     6   NaN     3
         4     5     6     6     6   NaN     3
         3   NaN     6     6   NaN     3     3
         3     9     9     9     4     4     3
         3     9     9     6     6   NaN     2
         0     9     4     2     2     2     0
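
    To see why some entries above differ from your expected results, here is a rough sketch of the same selection rule on one neighbourhood, the one centred on row 3, column 2 of the example input, which holds 8 valid values and one nan. The names picked and avgMid are only illustrative:

    ```matlab
    % 3 x 3 neighbourhood around I(3,2) from the example input
    w = [2 4 3; 4 5 8; 5 NaN 6];

    s = sort(w(:), 1, 'ascend');      % ascending sort places nan at the end
    [~, k] = max(isnan(s));           % index of the first nan (9 here, so 8 valid values)
    picked = s(ceil(k / 2));          % element selected by the rule used above

    % For comparison: average of the two middle valid values
    valid = s(~isnan(s));
    n = numel(valid);
    avgMid = (valid(ceil(n / 2)) + valid(floor(n / 2) + 1)) / 2;

    fprintf('picked: %g, averaged middle pair: %g\n', picked, avgMid);
    ```

    With an even number of valid values, the rule takes the upper of the two middle elements (5) instead of averaging them (4.5), which accounts for this entry differing from the nlfilter result.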
    

    Minor Note

    In recent versions of MATLAB, median accepts an optional third input called nanflag that determines what to do when nans are encountered. Setting the flag to omitnan ignores all nan elements in the calculation; the default, includenan, applies when you don't specify a third parameter. If you specify omitnan in the median call and skip the step from the first approach that sets the nan values to 0, you'll get exactly what you wanted from the output of nlfilter:

    r = 3;
    nanMask = isnan(I); % Define nan mask
    IP = padarray(I, floor([r/2 r/2]), 'both'); % Pad image
    IPc = im2col(IP, [r r], 'sliding'); % Transform into columns
    out = reshape(median(IPc, 1, 'omitnan'), size(I,1), size(I,2)); % Find median of each column and reshape back
    out(nanMask) = nan; % Set nan elements back
    

    We get:

    >> out
    
    out =
    
             0    2.0000    3.0000    3.0000    3.0000    2.5000         0
        2.0000    4.0000    6.0000    6.0000    6.0000       NaN    3.0000
        3.0000    4.5000    5.5000    6.0000    5.0000       NaN    3.0000
        2.0000       NaN    6.0000    6.0000       NaN    3.0000    2.5000
        2.0000    7.5000    9.0000    7.5000    4.0000    4.0000    2.5000
        3.0000    9.0000    9.0000    6.0000    5.0000       NaN    2.0000
             0    9.0000    4.0000    2.0000    1.5000    2.0000         0
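
    If your MATLAB release predates the nanflag option, the sorting idea from earlier can probably be extended to average the two middle valid values per column, which should mimic what omitnan does. This is only a sketch under that assumption: the matrix C below is a hypothetical stand-in for IPc (one neighbourhood per column), and lo, hi and med are names I made up:

    ```matlab
    % Stand-in for the output of im2col: one neighbourhood per column
    C = [  1 NaN   4 NaN 4;
           3   2 NaN NaN 1;
         NaN NaN   6 NaN 3;
           5   9   8 NaN 2];

    S = sort(C, 1, 'ascend');         % nan values end up at the bottom of each column
    n = sum(~isnan(S), 1);            % number of valid entries per column
    lo = max(ceil(n / 2), 1);         % lower middle row, clamped so all-nan columns read row 1 (a nan)
    hi = floor(n / 2) + 1;            % upper middle row, equal to lo when n is odd
    cols = 1:size(S, 2);

    % Average the two middle entries of every column in one shot
    med = (S(sub2ind(size(S), lo, cols)) + S(sub2ind(size(S), hi, cols))) / 2;
    disp(med);
    ```

    Columns with an odd count return their middle element, columns with an even count return the mean of the middle pair, and a column containing only nan stays nan.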
    

    A more efficient im2col solution

    User Divakar has implemented a faster version of im2col, which he has benchmarked and shown to be a lot faster than the im2col provided by MATLAB's Image Processing Toolbox. If you're going to call this code many times, consider using his implementation: Efficient Implementation of `im2col` and `col2im`
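
    For a feel of how a sliding im2col can be built from plain indexing, here is a minimal sketch. It is my own illustration of the general idea, not Divakar's code, and it is intended to produce the same column ordering as im2col(A, blk, 'sliding'). The names starts, offsets and idx are assumptions:

    ```matlab
    A = reshape(1:12, 3, 4);          % small test matrix, A(k) == k in linear indexing
    blk = [2 2];                      % block (window) size

    [M, N] = size(A);
    p = blk(1);
    q = blk(2);

    % Linear index of the top-left element of every block, scanning down columns first
    [ii, jj] = ndgrid(1:M-p+1, 1:N-q+1);
    starts = ii(:).' + (jj(:).' - 1) * M;       % 1 x numBlocks

    % Linear offsets of the elements inside one block, in column-major order
    [a, b] = ndgrid(0:p-1, 0:q-1);
    offsets = a(:) + b(:) * M;                  % (p*q) x 1

    % Every block's indices at once, then gather
    idx = bsxfun(@plus, offsets, starts);       % (p*q) x numBlocks
    cols = A(idx);
    disp(cols);
    ```

    Each column of cols holds one 2 x 2 block read column by column, so the first column is the block at the top-left corner and the last column is the block at the bottom-right corner.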

    Timing Tests

    To determine if the proposed approach is faster, I'll perform a timing test with timeit. First, I'm going to create a function that sets up the common variables and defines two nested functions: the first is the original method with nlfilter and the second is the proposed approach. I'm going to use the 'omitnan' variant, as that produces exactly the result you want.

    Here's the function I wrote. I've generated an input of 300 x 300 like how you've set yours up and it contains all random numbers between 0 and 1. I've made it so that approximately 20% of the numbers in this input have nan. I also set up the anonymous function you're using with nlfilter to filter the medians without nans as well as the neighbourhood size, which is 3 x 3. I then define two functions within this code - the original method where the code does the filtering with nlfilter and what I proposed above with the omitnan option:

    function time_nan
    
    % Initial setup
    rng(112234);
    I = rand(300,300);
    I(I < 0.2) = nan; % Modify approximately 20% of the values in the input with nan
    r = 3; % Median filter of size 3
    nanMask = isnan(I); % Determine which locations are nan
    f = @(A)median(A(~isnan(A))); % Handle to function used by nlfilter
    
        function original_method
            filteredRes = nlfilter(I, [r r], f);
            filteredRes(nanMask) = nan;
        end
    
        function new_method
            IP = padarray(I, floor([r/2 r/2]), 'both'); % Pad image
            IPc = im2col(IP, [r r], 'sliding'); % Transform into columns
            out = reshape(median(IPc, 1, 'omitnan'), size(I,1), size(I,2)); % Find median of each column and reshape back
            out(nanMask) = nan; % Set nan elements back
        end
    
    t1 = timeit(@original_method);
    t2 = timeit(@new_method);
    
    fprintf('Average time for original method: %f seconds\n', t1);
    fprintf('Average time for new method: %f seconds\n', t2);
    
    end
    

    My current machine is an HP ZBook G5 with 16 GB of RAM and an Intel Core i7 @ 2.80 GHz. When I run this code, I get the following result:

    Average time for original method: 1.033838 seconds
    Average time for new method: 0.038697 seconds
    

    As you can see, the new method runs approximately (1.033838 / 0.038697) = 26.7162x faster than nlfilter. Not bad!