
Numba Function Error: Handling 1D and 2D Arrays Differently


I'm getting an error from Numba-compiled functions that check whether the extent of an n-dimensional box (n >= 1) is at least a given minimum along each corresponding dimension. The functions get_extent and is_larger_than_min are both decorated with @njit.

Here are the functions:

import numpy as np
from numba import njit

@njit
def get_extent(box):
    return box[1] - box[0]

and

@njit
def is_larger_than_min(box, extent_min):
    extent = get_extent(box)
    return np.all(extent >= extent_min)

When I pass a 2D array box1 and its corresponding extent_min1, everything works fine. However, when I pass a 1D array box2 and its extent_min2, I encounter an error.

box1 = np.array([[0, 0, 0], [5, 5, 5]])   # shape (2, n), i.e. (n>1)-dimensional box
extent_min1 = np.array([4, 4, 4])   # shape (n,), i.e. extents along each dimension

box2 = np.array([0, 5])   # shape (2,), i.e. (n=1)-dimensional box, or just an interval
extent_min2 = 4   # scalar, i.e. extent (length) along this single dimension

is_larger_than_min(box1, extent_min1)   # works fine
is_larger_than_min(box2, extent_min2)   # raises error

The error message isn't very informative. However, without @njit, is_larger_than_min works with all of these inputs, so the problem is clearly caused by Numba handling scalars and arrays differently. How can I modify these functions so that they handle both scalars and arrays without errors? Any insights or solutions would be greatly appreciated. Thanks!
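The scalar-versus-array difference can be reproduced in plain NumPy, without Numba. Indexing a 2D box with box[1] returns a row, but indexing a 1D box returns a single element. As a result, get_extent returns an array in one case and a scalar in the other, and the comparison with extent_min inherits that difference. Plain NumPy's np.all accepts either, so the undecorated version works. This sketch only inspects types and shapes and does not call Numba.

```python
import numpy as np

# Same inputs as in the question.
box1 = np.array([[0, 0, 0], [5, 5, 5]])  # shape (2, 3)
box2 = np.array([0, 5])                  # shape (2,)

# For a 2D box, box[1] and box[0] are rows, so the difference is an array.
extent1 = box1[1] - box1[0]
# For a 1D box, box[1] and box[0] are single elements, so the difference is a scalar.
extent2 = box2[1] - box2[0]

print(isinstance(extent1, np.ndarray), np.ndim(extent1))  # True 1
print(isinstance(extent2, np.ndarray), np.ndim(extent2))  # False 0

# The comparison keeps the distinction: a boolean array vs. a single NumPy bool.
cmp1 = extent1 >= np.array([4, 4, 4])
cmp2 = extent2 >= 4
print(isinstance(cmp1, np.ndarray), isinstance(cmp2, np.ndarray))  # True False

# np.atleast_1d promotes the scalar to a one-element array of shape (1,).
print(np.atleast_1d(extent2).shape)  # (1,)
```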


Solution

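  • Modify get_extent so that it always returns an array with at least one dimension. For a 1D box, box[1] - box[0] is a scalar, which makes extent >= extent_min a scalar boolean. Numba's implementation of np.all, however, requires an array argument. Wrapping the difference in np.asarray and np.atleast_1d guarantees that np.all always receives an array: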

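    @njit
    def get_extent(box):
        # 1D box: the difference is a scalar; np.asarray makes it a 0-d array and
        # np.atleast_1d reshapes that to shape (1,). A 2D box already gives a 1D array.
        return np.atleast_1d(np.asarray(box[1] - box[0]))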
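Below is a self-contained sketch of the fixed functions, with the question's two inputs used as a usage check. It relies on the same np.asarray/np.atleast_1d combination as the fix above, so check Numba's list of supported NumPy functions for your version. The try/except fallback for njit is an illustration-only addition and is not part of the original answer. It lets the sketch run as plain Python when Numba isn't installed, but only the Numba path exercises the original problem.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Illustration-only fallback: without Numba, run the functions as plain Python.
    def njit(func):
        return func


@njit
def get_extent(box):
    # Scalar for a 1D box, 1D array for a 2D box; normalize both to ndim >= 1.
    return np.atleast_1d(np.asarray(box[1] - box[0]))


@njit
def is_larger_than_min(box, extent_min):
    extent = get_extent(box)
    # extent is always an array here, so np.all always receives an array.
    return np.all(extent >= extent_min)


box1 = np.array([[0, 0, 0], [5, 5, 5]])  # (n>1)-dimensional box
extent_min1 = np.array([4, 4, 4])

box2 = np.array([0, 5])                  # (n=1)-dimensional box, i.e. an interval
extent_min2 = 4

print(is_larger_than_min(box1, extent_min1))  # True
print(is_larger_than_min(box2, extent_min2))  # True
print(is_larger_than_min(box2, 6))            # False
```

Because Numba compiles a separate specialization for each combination of argument types, the 2D call and the 1D call each get their own compiled version. The fix only ensures that both versions see an array when they reach np.all.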