Tags: python, arrays, numpy, array-broadcasting

Subtract Mean from Multidimensional Numpy-Array


I'm currently learning about broadcasting in NumPy, and in the book I'm reading (Python for Data Analysis by Wes McKinney), the author gives the following example to "demean" a two-dimensional array:

import numpy as np

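arr = np.random.randn(4, 3)
print(arr.mean(0))            # column means, shape (3,)
demeaned = arr - arr.mean(0)  # subtract each column's mean from that column
print(demeaned)
print(demeaned.mean(0))       # approximately 0 for every column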

This makes every column of demeaned have a mean of 0 (up to floating-point rounding).
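A minimal sketch of why the 2-D case works: the column means have shape (3,), which lines up with the trailing dimension of the (4, 3) array, so they are repeated across every row. The seeded generator (np.random.default_rng, NumPy 1.17 or later) is an assumption added only to make the run reproducible; it is not from the book.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded generator, assumed for reproducibility
arr = rng.standard_normal((4, 3))

col_means = arr.mean(0)            # one mean per column, shape (3,)
print(arr.shape, col_means.shape)  # (4, 3) (3,) -> trailing dimensions match

demeaned = arr - col_means         # the (3,) vector is broadcast across all 4 rows
# Column means are zero only up to rounding, so compare with a tolerance.
print(np.allclose(demeaned.mean(0), 0.0))  # True
```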

I had the idea to apply this to an image-like, three-dimensional array:

import numpy as np

arr = np.random.randint(0, 256, (400, 400, 3))  # image-like: height, width, channels
demeaned = arr - arr.mean(2)  # raises ValueError

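This fails with a ValueError: broadcasting compares shapes starting from the trailing (rightmost) dimension, and each pair of dimensions must either be equal or have one of them equal to 1. Here the trailing dimension 3 is compared with 400: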

print(arr.shape)  # (400, 400, 3)
print(arr.mean(2).shape)  # (400, 400)
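A small sketch that reproduces the failure and checks, without allocating anything, which shape would broadcast. It assumes NumPy 1.20 or later for np.broadcast_shapes; the names failed and fixed_shape are just illustrative.

```python
import numpy as np

arr = np.random.randint(0, 256, (400, 400, 3))
means = arr.mean(2)  # shape (400, 400)

# Shapes are aligned right to left: 3 vs 400 is neither equal nor 1, so this fails.
failed = False
try:
    arr - means
except ValueError as exc:
    failed = True
    print("broadcast failed:", exc)

# Appending an axis of length 1 to the means makes the shapes compatible.
fixed_shape = np.broadcast_shapes(arr.shape, means.shape + (1,))
print(fixed_shape)  # (400, 400, 3)
```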

I have mostly gotten it to work by subtracting the means from each slice along the third dimension of the array:

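means = arr.mean(2)  # per-pixel mean over the 3 channels, shape (400, 400)
demeaned = np.empty(arr.shape)  # float64 output array, filled channel by channel

for i in range(3):
    demeaned[..., i] = arr[..., i] - means

print(demeaned.mean(2))  # per-pixel means after demeaning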

The printed values are very close to zero but not exactly zero, and I think that's floating-point rounding error. Am I right about this, or is there another caveat I missed?
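One way to check that the leftovers are rounding error rather than a logic bug is to look at their magnitude and compare with zero using a tolerance. This sketch assumes float64 arithmetic, which is what mean returns for an integer array; the name residual is illustrative.

```python
import numpy as np

arr = np.random.randint(0, 256, (400, 400, 3))
means = arr.mean(2)  # per-pixel channel means, usually not exactly representable (e.g. 100/3)

demeaned = np.empty(arr.shape)
for i in range(3):
    demeaned[..., i] = arr[..., i] - means

residual = demeaned.mean(2)         # per-pixel mean after demeaning
print(np.abs(residual).max())       # tiny value caused by rounding (may even be 0.0)
print(np.allclose(residual, 0.0))   # True: zero within floating-point tolerance
```

Since values of order 255 carry a relative rounding error of about 1e-16, leftovers many orders of magnitude below 1 are expected and are not a bug.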

Also, this doesn't seem to be the cleanest or most NumPy-like way to do this. Is there a function or principle I can use to improve the code?


Solution

  • As of NumPy 1.7.0, np.mean and several other functions accept a tuple for their axis parameter. This means you can compute the mean of each plane (color channel) of the image over all pixels in a single call:

    m = arr.mean(axis=(0, 1))
    

    This mean has shape (3,), with one element for each plane of the image, so it matches the trailing dimension of arr and arr - m broadcasts directly.

    If you instead want to subtract each pixel's own mean (taken across its channels), remember that broadcasting aligns shape tuples at the right edge. arr.mean(axis=2) has shape (400, 400), so you need to insert a trailing dimension of length 1:

    n = arr.mean(axis=2)
    n = n.reshape(*n.shape, 1)  # shape (400, 400, 1)
    

    Or

    n = arr.mean(axis=2)[..., None]  # None (np.newaxis) adds a trailing axis of length 1
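A minimal sketch putting both approaches side by side on random image-like data. The keepdims=True variant is an extra option not mentioned in the answer, though it is a standard parameter of np.mean; the names per_plane, per_pixel and loop are illustrative only.

```python
import numpy as np

arr = np.random.randint(0, 256, (400, 400, 3))

# Per-plane (per-channel) means over all pixels: shape (3,), broadcasts directly.
m = arr.mean(axis=(0, 1))
per_plane = arr - m
print(np.allclose(per_plane.mean(axis=(0, 1)), 0.0))  # True

# Per-pixel means over the channels: add a trailing axis of length 1.
n1 = arr.mean(axis=2)
n1 = n1.reshape(*n1.shape, 1)         # shape (400, 400, 1)
n2 = arr.mean(axis=2)[..., None]      # same shape via indexing
n3 = arr.mean(axis=2, keepdims=True)  # keepdims leaves the reduced axis with length 1
per_pixel = arr - n2

print(np.allclose(n1, n2) and np.allclose(n2, n3))  # True
print(np.allclose(per_pixel.mean(axis=2), 0.0))     # True

# Same result as the explicit loop over channels from the question.
loop = np.empty(arr.shape)
for i in range(3):
    loop[..., i] = arr[..., i] - arr.mean(2)
print(np.allclose(loop, per_pixel))  # True
```

Either form of the trailing axis works; keepdims=True has the advantage of never needing to know how many dimensions the reduced array has.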