python numpy numba jit

# How to find extrema per cell in a 3-dimensional array with Numba?

I have recently written a script to convert BGR arrays of [0, 1] floats to HSL and back. I posted it on Code Review. There is currently one answer but it doesn't improve performance.

I have benchmarked my code against `cv2.cvtColor` and found my code to be inefficient, so I want to compile the code with Numba to make it run faster.

I have tried wrapping every function with `@nb.njit(cache=True, fastmath=True)`, but compilation fails.

So I tested each NumPy construct and function I used individually, and found two functions that don't work with Numba.

I need to find the maximum channel of each pixel (`np.max(img, axis=-1)`) and the minimum channel of each pixel (`np.min(img, axis=-1)`), and the `axis` argument doesn't work with Numba.

I searched for this, but the only remotely relevant result I found implements only `np.any` and `np.all`, and only for two-dimensional arrays, whereas my arrays are three-dimensional.

I could write a for-loop-based solution, but I'd rather not, because I expect it to be inefficient and it defeats the purpose of using NumPy and Numba in the first place.

Minimal reproducible example:

``````
import numba as nb
import numpy as np

@nb.njit(cache=True, fastmath=True)
def max_per_cell(arr):
    return np.max(arr, axis=-1)

@nb.njit(cache=True, fastmath=True)
def min_per_cell(arr):
    return np.min(arr, axis=-1)

img = np.random.random((3, 4, 3))
max_per_cell(img)
min_per_cell(img)
``````

Exception:

``````
In [2]: max_per_cell(img)
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[2], line 1
----> 1 max_per_cell(img)

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465                f"by the following argument(s):\n{args_str}\n")
466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470     # Something unsupported is present in the user code, add help info
471     error_rewrite(e, 'unsupported_error')

File C:\Python310\lib\site-packages\numba\core\dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407     raise e
408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function amax at 0x0000014E306D3370>) found for signature:

>>> amax(array(float64, 3d, C), axis=Literal[int](-1))

There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'npy_max': File: numba\np\arraymath.py: Line 541.
With argument(s): '(array(float64, 3d, C), axis=int64)':
Rejected as the implementation raised a specific error:
TypingError: got an unexpected keyword argument 'axis'
raised from C:\Python310\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function amax at 0x0000014E306D3370>)
During: typing of call at <ipython-input-1-b3894b8b12b8> (10)

File "<ipython-input-1-b3894b8b12b8>", line 10:
def max_per_cell(arr):
    return np.max(arr, axis=-1)
    ^
``````

How can I fix this?

## Solution

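It's reasonably straightforward to implement this without `np.max()`, using explicit loops instead. Numba compiles the loops to machine code, so they are not slow the way interpreted Python loops are: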

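``````
import numba as nb
import numpy as np


@nb.njit()
def max_per_cell_nb(arr):
    # One output value per pixel: drop the trailing channel axis.
    ret = np.empty(arr.shape[:-1], dtype=arr.dtype)
    n, m = ret.shape
    for i in range(n):
        for j in range(m):
            # Running maximum over the three channels of pixel (i, j).
            max_ = arr[i, j, 0]
            max_ = max(max_, arr[i, j, 1])
            max_ = max(max_, arr[i, j, 2])
            ret[i, j] = max_
    return ret
``````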

Benchmarking `max_per_cell_nb` shows it to be about 16x faster than plain NumPy `np.max(arr, axis=-1)`:

``````
%timeit max_per_cell_nb(img)
4.88 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit max_per_cell(img)
81 ms ± 654 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
``````

While benchmarking this, I made the following assumptions:

• The image is 1920x1080x3. (In other words, it's a big image.)
• The image array is in C order rather than Fortran order. If it's in Fortran order, my method slows down to about 7 ms, while `np.max()` speeds up to about 15 ms. See "Check if numpy array is contiguous?" for how to tell whether your array is in C or Fortran order. Your example, `np.random.random((3, 4, 3))`, is C-contiguous.
• I'm comparing this function to `np.max(arr, axis=-1)` with Numba JIT turned off, because Numba can't really optimize a single call to a NumPy function.
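The memory-order assumption can be checked through the array's `flags` attribute. A short sketch, using a small random array purely for illustration (the variable names are assumptions):

```python
import numpy as np

img_c = np.random.random((4, 5, 3))   # NumPy allocates in C order by default
img_f = np.asfortranarray(img_c)      # same values, Fortran (column-major) layout

print(img_c.flags['C_CONTIGUOUS'], img_c.flags['F_CONTIGUOUS'])  # True False
print(img_f.flags['C_CONTIGUOUS'], img_f.flags['F_CONTIGUOUS'])  # False True

# If a Fortran-ordered array turns up, a C-ordered copy can be made
# before calling the loop-based function.
img_back = np.ascontiguousarray(img_f)
print(img_back.flags['C_CONTIGUOUS'])  # True
```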