I'm encountering an error while using Numba-optimized functions to check whether the extent of an n-dimensional box (n >= 1) is larger than a minimum value along the corresponding dimensions. The functions get_extent and is_larger_than_min are decorated with @njit.
Here are the functions:
import numpy as np
from numba import njit

@njit
def get_extent(box):
    return box[1] - box[0]
and
@njit
def is_larger_than_min(box, extent_min):
    extent = get_extent(box)
    return np.all(extent >= extent_min)
When I pass a 2D array box1 and its corresponding extent_min1, everything works fine. However, when I pass a 1D array box2 and its extent_min2, I encounter an error.
box1 = np.array([[0, 0, 0], [5, 5, 5]]) # shape (2, n), i.e. (n>1)-dimensional box
extent_min1 = np.array([4, 4, 4]) # shape (n,), i.e. extents along each dimension
box2 = np.array([0, 5]) # shape (2,), i.e. (n=1)-dimensional box, or just an interval
extent_min2 = 4 # scalar, i.e. extent (length) along this single dimension
is_larger_than_min(box1, extent_min1) # works fine
is_larger_than_min(box2, extent_min2) # raises error
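For reference, the same logic runs fine for both inputs when it is not compiled; here the plain-Python copies get a _py suffix purely to keep them apart from the decorated versions:

def get_extent_py(box):
    return box[1] - box[0]

def is_larger_than_min_py(box, extent_min):
    extent = get_extent_py(box)
    return np.all(extent >= extent_min)

is_larger_than_min_py(box1, extent_min1)  # True
is_larger_than_min_py(box2, extent_min2)  # True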
The error message I receive is not very informative, but since is_larger_than_min works just fine with any type of input once the @njit decorators are removed (as shown above), the problem is clearly related to how scalars and arrays are handled differently inside the Numba-compiled functions. How can I modify these functions so that they handle both scalars and arrays without errors? Any insights or solutions would be greatly appreciated. Thanks!
You can modify your get_extent function so that it always returns an array with at least one dimension. That way np.all always receives an array, which the Numba implementation requires:
@njit
def get_extent(box):
    return np.atleast_1d(np.asarray(box[1] - box[0]))
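The np.asarray wrapper matters for the 1D case: there box[1] - box[0] yields a scalar, which np.asarray turns into a 0-d array that np.atleast_1d can then promote to shape (1,). For a 2D box the subtraction already produces a 1D array and both wrappers pass it through unchanged. With this change, both of your original calls should compile and return True; a quick sanity check, assuming the definitions and inputs from your post:

print(is_larger_than_min(box1, extent_min1))  # True
print(is_larger_than_min(box2, extent_min2))  # True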