
How to find extrema per cell in 3 dimensional array with Numba?


I have recently written a script to convert BGR arrays of [0, 1] floats to HSL and back. I posted it on Code Review. There is currently one answer but it doesn't improve performance.

I have benchmarked my code against cv2.cvtColor and found my code to be inefficient, so I want to compile the code with Numba to make it run faster.

I tried wrapping every function with @nb.njit(cache=True, fastmath=True), but that doesn't work.

So I tested each NumPy function and piece of syntax I use individually, and found two functions that don't work with Numba.

I need to find the maximum channel of each pixel (np.max(img, axis=-1)) and the minimum channel of each pixel (np.min(img, axis=-1)), but Numba rejects the axis argument for both functions.

I searched for this, but the only remotely relevant result I found implements just np.any and np.all, and it only works for two-dimensional arrays, whereas my arrays are three-dimensional.

I could write a for-loop-based solution, but I would rather not: I expect it to be inefficient, and it defeats the purpose of using NumPy and Numba in the first place.

Minimal reproducible example:

import numba as nb
import numpy as np

@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
    return np.max(arr, axis=-1)

@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
    return np.min(arr, axis=-1)

img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)

Exception:

In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:

 >>> amax(array(float64, 3d, C), axis=Literal[int](-1))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
    With argument(s): '(array(float64, 3d, C), axis=int64)':
   Rejected as the implementation raised a specific error:
     TypingError: got an unexpected keyword argument 'axis'
  raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)


File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
    return np.max(arr, axis=-1)
    ^

How can I fix this?


Solution

  • It's reasonably straightforward to implement this without np.max(), using loops instead:

    @nb.njit()
    def max_per_cell_nb(arr):
        # Output drops the last (channel) axis: one value per pixel.
        ret = np.empty(arr.shape[:-1], dtype=arr.dtype)
        n, m = ret.shape
        for i in range(n):
            for j in range(m):
                # Hard-coded for exactly 3 channels (e.g. B, G, R).
                max_ = arr[i, j, 0]
                max_ = max(max_, arr[i, j, 1])
                max_ = max(max_, arr[i, j, 2])
                ret[i, j] = max_
        return ret
    

    In my benchmark, it is about 16x faster than np.max(arr, axis=-1):

    %timeit max_per_cell_nb(img)
    4.88 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    %timeit max_per_cell(img)
    81 ms ± 654 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    While benchmarking this, I made the following assumptions:

    • The image is 1920x1080x3. (In other words, it's a big image.)
    • The image array is in C order rather than Fortran order. If it's in Fortran order, my method slows down to 7 ms, while np.max() gets faster and takes only 15 ms. See "Check if numpy array is contiguous?" for how to tell whether your array is in C or Fortran order. Your example of np.random.random((3, 4, 3)) is C contiguous.
    • I'm comparing this function to plain np.max(arr, axis=-1), without the Numba JIT, because Numba can't really optimize a single call to a NumPy function.