Tags: python, arrays, numpy, array-broadcasting

Numpy 3d array - normalize rows


I've been working on normalizing image data with horizontal dark/bright lines.

Here is a minimal working example. Consider a 3D array holding three images, stacked along the last axis:

import numpy as np
a = np.arange(3)
a = np.vstack([a, a + 3, a + 6])
a = np.repeat(a, 3, axis=0)
a = a.reshape(3, 3, 3)  # axes: (row, column, image)
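A quick sketch to confirm the axis layout of this toy stack. It assumes, as the indexing a[:,:,i] in the question implies, that the first two axes are the rows and columns of each image and the last axis selects the image. In this toy data every row is constant, so the streaks are pure per-row offsets.

```python
import numpy as np

# Rebuild the toy stack from the question
a = np.arange(3)
a = np.vstack([a, a + 3, a + 6])
a = np.repeat(a, 3, axis=0).reshape(3, 3, 3)

# Axes are (row, column, image): a[:, :, i] is image i
print(a.shape)     # (3, 3, 3)
print(a[:, :, 1])  # image #2: a row of 1s, a row of 4s, a row of 7s

# Every value depends only on (row, image), i.e. each row is constant
print((a == a[:, :1, :]).all())  # True
```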

For instance, image #2, a[:,:,1], consists of

[[1 1 1]
 [4 4 4]
 [7 7 7]]

[Image: visual representation of image #2]

This image shows horizontal streaks. To remove them, I calculate the median of each row, the overall median of each image, and a difference matrix:

row_median = np.median(a, axis=1)                       # median of each row, per image
bg_median = np.tile(np.median(a, axis=(0, 1)), (3, 1))  # overall median of each image, one copy per row
difference_matrix = bg_median - row_median
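The shapes of these intermediates are what make a broadcasting solution possible. The sketch below rebuilds the toy stack so it runs on its own and prints the shapes; `img_median` is a name introduced here for the per-image median and does not appear in the original code.

```python
import numpy as np

# Toy stack from the question, axes (row, column, image)
a = np.repeat(np.vstack([np.arange(3), np.arange(3) + 3, np.arange(3) + 6]), 3, axis=0).reshape(3, 3, 3)

row_median = np.median(a, axis=1)        # shape (rows, images): median of every row of every image
img_median = np.median(a, axis=(0, 1))   # shape (images,): median of every whole image
bg_median = np.tile(img_median, (3, 1))  # shape (rows, images): img_median copied once per row
difference_matrix = bg_median - row_median

print(row_median.shape, img_median.shape, difference_matrix.shape)  # (3, 3) (3,) (3, 3)
print(difference_matrix[:, 1])                                      # [ 3.  0. -3.]
```

Because `img_median` has shape (images,), it already broadcasts against `row_median` of shape (rows, images) along the last axis, so `img_median - row_median` gives the same difference matrix without the `np.tile` copy.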

Then I iterate through the images and add the corresponding row offsets from the difference matrix to each one:

for i in range(a.shape[2]):  # the images are on the last axis
    a[:,:,i] = a[:,:,i] + np.tile(difference_matrix[:,i],(a.shape[1],1)).T

This gives the desired result; for example, a[:,:,1] becomes:

[[4 4 4]
 [4 4 4]
 [4 4 4]]

This procedure is very slow for large images and large stacks of images. I would appreciate any comments or hints on improving the performance of my code, possibly through the use of broadcasting.

EDIT: Following up on Divakar's answer

a = a.astype('float64')  # the medians are floats, so the in-place add needs a float array
diff = np.median(a, axis=(0,1)) - np.median(a, axis=0)
a += diff[None,:]

solved the issue for me.
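The astype('float64') step matters because np.median returns float64 even for integer input, and NumPy refuses an in-place add of a float array into an integer array under its default 'same_kind' casting rule. A minimal sketch reproducing this on an arbitrary integer array (the array and the `offset` and `failed` names are illustrative only):

```python
import numpy as np

a = np.arange(27).reshape(3, 3, 3)  # any integer array
offset = np.median(a, axis=(0, 1))  # np.median returns float64, even for integer input
print(offset.dtype)                 # float64

failed = False
try:
    a += offset                     # in-place int += float violates the 'same_kind' casting rule
except TypeError as err:            # NumPy raises a TypeError (UFuncTypeError in recent versions)
    failed = True
    print("in-place add failed:", err)

a = a.astype(np.float64)            # convert once up front ...
a += offset                         # ... and the in-place add succeeds
print(failed, a.dtype)              # True float64
```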


Solution

  • Approach #1

    Leverage broadcasting by extending dimensions with None/np.newaxis instead of tiling the intermediate arrays. This saves memory and hence improves performance. The changes would be:

    # a must be a float array here (e.g. a = a.astype('float64')), since the medians are floats
    diff = np.median(a, axis=(0,1)) - np.median(a, axis=1)
    a += diff[:,None]
    

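    The None inserts a singleton column axis, turning diff from shape (rows, images) into (rows, 1, images), and broadcasting then stretches it across the columns under the hood.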

  • Approach #2

    Alternatively, a more explicit way to keep track of the dimensions is to retain them during the reductions, which avoids the final dimension extension with None. To do so, pass keepdims=True to both median calls:

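    # shapes: (1, 1, images) - (rows, 1, images) -> (rows, 1, images); a must again be a float array
    diff = np.median(a, axis=(0,1), keepdims=True) - np.median(a, axis=1, keepdims=True)
    a += diff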
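A self-contained sketch for checking that both approaches reproduce the question's loop and for timing them. The stack size, the random row streaks, and the helper names loop_version, approach1 and approach2 are illustrative assumptions, not part of the answer; timings will vary by machine and array size.

```python
import timeit

import numpy as np

# Hypothetical test stack: 128 x 128 pixels, 20 images on the last axis,
# with a random offset added to every row of every image to mimic streaks
rng = np.random.default_rng(0)
stack = rng.normal(100.0, 5.0, size=(128, 128, 20))
stack += rng.normal(0.0, 20.0, size=(128, 1, 20))


def loop_version(a):
    """The question's tiling loop, with the sizes taken from a.shape."""
    a = a.copy()
    rows, cols, n_images = a.shape
    row_median = np.median(a, axis=1)
    bg_median = np.tile(np.median(a, axis=(0, 1)), (rows, 1))
    difference_matrix = bg_median - row_median
    for i in range(n_images):
        a[:, :, i] = a[:, :, i] + np.tile(difference_matrix[:, i], (cols, 1)).T
    return a


def approach1(a):
    """Broadcasting with an explicit None axis."""
    a = a.copy()
    diff = np.median(a, axis=(0, 1)) - np.median(a, axis=1)
    a += diff[:, None]
    return a


def approach2(a):
    """Broadcasting with keepdims=True."""
    a = a.copy()
    a += np.median(a, axis=(0, 1), keepdims=True) - np.median(a, axis=1, keepdims=True)
    return a


reference = loop_version(stack)
print(np.allclose(reference, approach1(stack)), np.allclose(reference, approach2(stack)))

for func in (loop_version, approach1, approach2):
    seconds = timeit.timeit(lambda: func(stack), number=5)
    print(f"{func.__name__}: {seconds:.3f} s for 5 runs")
```

All three versions compute the same two medians; the broadcasting versions differ only in skipping the Python loop and the tiled temporaries, so the measured gap depends on how large those are relative to the median computations.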