Tags: python, numpy, complex-numbers, memory-efficient, numpy-ufunc

Most memory-efficient way to compute abs()**2 of complex numpy ndarray


I'm looking for the most memory-efficient way to compute the squared absolute value of a complex numpy ndarray:

arr = np.empty((250000, 150), dtype='complex128')  # common size

I haven't found a ufunc that would do exactly np.abs()**2.

An array of that size and type takes up about 0.6 GB (250000 * 150 elements at 16 bytes each), so memory efficiency is my main concern.

I would also like it to be portable, so ideally some combination of ufuncs.

So far my understanding is that this should be about the best:

result = np.abs(arr)
result **= 2

It needlessly takes a square root and then squares it again, i.e. (...**0.5)**2, but the **2 should at least be computed in place. Altogether the peak memory requirement is only the original array size + the result array size, which should be 1.5 * original array size, since the real result is half the size of the complex input.
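As a rough sanity check of that arithmetic, here is a minimal sketch that works out the sizes from the dtypes without allocating the full array, and then confirms on a small stand-in array that `**=` reuses the result buffer. The small array and the variable names are only for illustration.

```python
import numpy as np

shape = (250000, 150)
n_elements = shape[0] * shape[1]

# complex128 stores two float64 values per element, float64 stores one.
input_bytes = n_elements * np.dtype(np.complex128).itemsize
result_bytes = n_elements * np.dtype(np.float64).itemsize

print(input_bytes / 1e9)                            # 0.6
print(result_bytes / 1e9)                           # 0.3
print((input_bytes + result_bytes) / input_bytes)   # 1.5

# Small stand-in array: check that the in-place power keeps the same buffer.
small = (np.arange(6) + 1j * np.arange(6)).reshape(2, 3)
result = np.abs(small)
address_before = result.ctypes.data
result **= 2
same_buffer = result.ctypes.data == address_before
print(same_buffer)                                  # True
```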

If I wanted to get rid of the redundant square root and re-squaring, I'd have to do something like this:

result = arr.real**2
result += arr.imag**2

but if I'm not mistaken, this allocates a separate array for the squared real part and another for the squared imaginary part, so the peak memory usage would be 2.0 * original array size. The arr.real and arr.imag properties also return non-contiguous views (but that is of lesser concern).
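To make that accounting concrete, here is a small sketch on a stand-in array (not the full-size one). It spells out the second allocation by naming it `tmp`, which is just an illustrative name for the temporary that `result += arr.imag**2` would otherwise create implicitly.

```python
import numpy as np

# Small stand-in for the real array, same dtype (complex128).
arr = (np.arange(12) + 1j * np.arange(12)).reshape(4, 3)

re, im = arr.real, arr.imag
# Both are views into arr's buffer that step over every other float64,
# so neither is C-contiguous.
views_share_memory = np.shares_memory(re, arr) and np.shares_memory(im, arr)
real_is_contiguous = re.flags.c_contiguous

result = arr.real**2   # first new allocation: half the size of arr
tmp = arr.imag**2      # second new allocation: another half
result += tmp

# Together the two float arrays occupy as much memory as arr itself,
# hence the 2.0 * original array size peak.
print(result.nbytes + tmp.nbytes == arr.nbytes)
```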

Is there anything I'm missing? Are there any better ways to do this?

EDIT 1: Sorry for not making this clear: I don't want to overwrite arr, so I can't use it as out.


Solution

  • Thanks to numba.vectorize in recent versions of numba, creating a numpy universal function for the task is very easy:

    import numba

    # Compile a ufunc with explicit signatures for double and single precision.
    @numba.vectorize([numba.float64(numba.complex128), numba.float32(numba.complex64)])
    def abs2(x):
        return x.real**2 + x.imag**2
    

    On my machine, I find a roughly threefold speedup (4.6 µs vs. 12.7 µs) compared to a pure-numpy version that creates intermediate arrays:

    >>> x = np.random.randn(10000).view('c16')
    >>> y = abs2(x)
    >>> np.all(y == x.real**2 + x.imag**2)   # exactly equal, being the same operation
    True
    >>> %timeit np.abs(x)**2
    10000 loops, best of 3: 81.4 µs per loop
    >>> %timeit x.real**2 + x.imag**2
    100000 loops, best of 3: 12.7 µs per loop
    >>> %timeit abs2(x)
    100000 loops, best of 3: 4.6 µs per loop