I am attempting to compile what is effectively a piecewise function using numba.njit. The Python function is defined as follows, using Numpy (for anyone interested in the Sympy origins of this issue, see the Notes below):
Minimal Example
from numpy import select, less, nan

def f(t):
    condlist = [less(t, 5), less(t, 15), less(t, 20), True]
    choicelist = [1, 0, 1, 0]
    return select(condlist, choicelist, default=nan)
See below for confirmation that this function works in Python.
The Issue: However, Numba fails to JIT this function in nopython mode:
import numpy as np
from numba import njit
jit_f = njit(f)
x = np.linspace(0, 50, 500)
jit_f(x)
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Input In [86], in <cell line: 5>()
2 jit_f = njit(f)
4 x = np.linspace(0,50,500)
----> 5 jit_f(x)
File /usr/local/lib/python3.8/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File /usr/local/lib/python3.8/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function select at 0x105f4d310>) found for signature:
>>> select(LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True))), list(int64)<iv=[1, 0, 1, 0]>, default=float64)
There are 2 candidate implementations:
- Of which 1 did not match due to:
Overload in function 'np_select': File: numba/np/arraymath.py: Line 4358.
With argument(s): '(Poison<LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True)))>, list(int64)<iv=None>, default=float64)':
Rejected as the implementation raised a specific error:
TypingError: Poison type used in arguments; got Poison<LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True)))>
raised from /usr/local/lib/python3.8/site-packages/numba/core/types/functions.py:236
- Of which 1 did not match due to:
Overload in function 'np_select': File: numba/np/arraymath.py: Line 4358.
With argument(s): '(LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True))), list(int64)<iv=[1, 0, 1, 0]>, default=float64)':
Rejected as the implementation raised a specific error:
NumbaTypeError: condlist must be a List or a Tuple
raised from /usr/local/lib/python3.8/site-packages/numba/np/arraymath.py:4375
During: resolving callee type: Function(<function select at 0x105f4d310>)
During: typing of call at /var/folders/dt/q6vbs0g56s70g4p2kyfj4tvh0000gn/T/ipykernel_61924/130570246.py (5)
File "../../../../../../../../../var/folders/dt/q6vbs0g56s70g4p2kyfj4tvh0000gn/T/ipykernel_61924/130570246.py", line 5:
<source missing, REPL/exec in use?>
I'm not a Numba expert, but my feeling is that there is some syntax error. I've played around with passing Numpy arrays and different formats of condlist and choicelist, but no luck so far.
Other Notes
The Python function behaves as expected, in this case giving some binary oscillations and then zero:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 50, 500)
plt.plot(x, f(x))
For any Sympy aficionados, the broader problem here is using Numba to JIT-compile a lambda generated via Sympy from sympy.Piecewise. A lambda very similar to f(t) in the above example can be autogenerated by sympy.lambdify on a Piecewise function.
Numba does not currently implement all Numpy functions, and the support for those it does implement is sometimes limited. You can find the list of supported functions in the documentation. For np.select, the documentation states that the support is limited to:
only using homogeneous lists or tuples for the first two arguments, condlist and choicelist. Additionally, these two arguments can only contain arrays (unlike Numpy that also accepts tuples).
The problem is that condlist is not homogeneous: the first three items of the list are arrays, while the last is a boolean value. Additionally, choicelist contains integers, whereas it must contain arrays.
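To see the mismatch concretely, here is a minimal sketch in plain NumPy. The helper `is_homogeneous_array_list` is a hypothetical name introduced only for illustration; it mimics the documented restriction and is not Numba's actual type check.

```python
import numpy as np


def is_homogeneous_array_list(items):
    """Hypothetical helper: True if every entry is an ndarray of the same dtype and ndim."""
    if not all(isinstance(a, np.ndarray) for a in items):
        return False
    return len({(a.dtype, a.ndim) for a in items}) <= 1


t = np.linspace(0, 50, 500)

# The lists from the question: arrays mixed with a Python bool, and plain integers.
condlist = [np.less(t, 5), np.less(t, 15), np.less(t, 20), True]
choicelist = [1, 0, 1, 0]
print(is_homogeneous_array_list(condlist))    # False
print(is_homogeneous_array_list(choicelist))  # False

# Lists in which every entry is an array of one dtype and shape.
condlist_ok = [np.less(t, 5), np.less(t, 15), np.less(t, 20), np.full(t.size, True)]
choicelist_ok = [np.ones(t.size), np.zeros(t.size), np.ones(t.size), np.zeros(t.size)]
print(is_homogeneous_array_list(condlist_ok))    # True
print(is_homogeneous_array_list(choicelist_ok))  # True
```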
One solution to fix this problem is to use the following code:
def f(t):
    condlist = [less(t, 5), less(t, 15), less(t, 20), np.full(t.size, True)]
    all_zeros = np.zeros(t.size)
    all_ones = np.ones(t.size)
    choicelist = [all_ones, all_zeros, all_ones, all_zeros]
    return select(condlist, choicelist, default=nan)
However, please do not use this code, as it is inefficient: it creates many temporary arrays that are slow to allocate and fill. The code will almost certainly be memory-bound, and memory bandwidth is a scarce resource that has improved only slowly over the last decades (this is called the "memory wall"). Optimizing such code is hard, and Numba is not faster than Numpy for it. In fact, Numpy is already quite efficient here, since it is implemented in C and most functions are carefully optimized. Numba is fast when you use loops and avoid creating (useless) temporary arrays. In short: Numba likes loops, as opposed to Numpy. Here is a much faster solution:
def f(t):
    result = np.empty(t.size)
    for i in range(t.size):
        result[i] = t[i] < 5 or 15 <= t[i] < 20
    return result
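As a sanity check, the loop version can be compiled and compared against a select-based reference. This is only a sketch: the names `f_loop` and `f_select` are introduced here for illustration, and the `try`/`except` fallback is an assumption added so that the snippet still runs (uncompiled) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch only: run as plain Python if Numba is absent.
    def njit(func):
        return func


@njit
def f_loop(t):
    # Same loop as above: 1.0 where t < 5 or 15 <= t < 20, else 0.0.
    result = np.empty(t.size)
    for i in range(t.size):
        result[i] = t[i] < 5 or 15 <= t[i] < 20
    return result


def f_select(t):
    # Plain-NumPy reference using array-only condlist and choicelist.
    condlist = [t < 5, t < 15, t < 20, np.full(t.size, True)]
    ones, zeros = np.ones(t.size), np.zeros(t.size)
    return np.select(condlist, [ones, zeros, ones, zeros], default=np.nan)


x = np.linspace(0, 50, 500)
same = np.array_equal(f_loop(x), f_select(x))
print(same)
```

The first call to `f_loop` includes compilation time when Numba is present, so any timing comparison should exclude that first call.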
Note that using a boolean or a short integer (e.g. int8) output type should be even faster: floating-point numbers are slower to compute and take more space in memory (np.empty creates float64 by default, which occupies 8 bytes per element versus 1 byte for int8).
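A possible int8 variant is sketched below. The name `f_int8` is hypothetical, the explicit `if`/`else` is just one way to write the assignment, and the import fallback is again an assumption so that the snippet runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch only: run as plain Python if Numba is absent.
    def njit(func):
        return func


@njit
def f_int8(t):
    # One byte per element instead of the eight used by float64.
    result = np.empty(t.size, dtype=np.int8)
    for i in range(t.size):
        if t[i] < 5 or 15 <= t[i] < 20:
            result[i] = 1
        else:
            result[i] = 0
    return result


x = np.linspace(0, 50, 500)
y = f_int8(x)
print(y.dtype, y.nbytes)        # int8 500
print(np.empty(x.size).nbytes)  # 4000
```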