A user on GitHub reported a bug with the following code, run in numba's nopython mode:
```python
from numba import njit
import numpy as np

@njit
def foo():
    a = np.ones(1, np.bool_)
    if a > 0:
        print('truebr')
    else:
        print('falsebr')

foo()
```
He was told that the expression `a > 0` is not a predicate but rather a conditional, and that to fix it he should "Wrap conditionals in truth to create predicates".
Does this mean that `(a > 0) == True` would fix the bug that comes up in numba, or does it mean something else?
Is https://github.com/numba/numba/pull/3901/commits/598cdd1707fdeb11b8f1d70aef2d3e36ef37bd34 the fix for this kind of error in numba?
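Whatever the linked commit changes inside numba, writing `(a > 0) == True` by itself is unlikely to help, because in NumPy that comparison is also elementwise: the result is another boolean array, not a single truth value. A quick plain-NumPy check (numba is not involved here):

```python
import numpy as np

a = np.ones(1, np.bool_)

cond = a > 0                # elementwise comparison -> boolean array
wrapped = (a > 0) == True   # still elementwise -> still a boolean array

# Both are ndarrays of shape (1,), so an `if` still has to ask
# "what is the truth value of this array?"
print(type(cond).__name__, cond.shape, cond.dtype)
print(type(wrapped).__name__, wrapped.shape, wrapped.dtype)
```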
In Python (not numba) the function works:
```
In [412]: def foo():
     ...:     a = np.ones(1, np.bool_)
     ...:     if a > 0:
     ...:         print('truebr')
     ...:     else:
     ...:         print('falsebr')
     ...:
In [413]: foo()
truebr
```
But if `a` is an array with more than one element:
```
In [414]: def foo():
     ...:     a = np.ones(2, np.bool_)
     ...:     if a > 0:
     ...:         print('truebr')
     ...:     else:
     ...:         print('falsebr')
     ...:
In [415]: foo()
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
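The `ValueError` is NumPy's general rule for arrays used in a boolean context: a one-element array can be coerced to a single `True`/`False`, but a longer array cannot, so you must say whether you mean "any element" or "all elements". A small plain-NumPy illustration (the helper `truth` is hypothetical, written only for this demo):

```python
import numpy as np

def truth(arr):
    """Hypothetical helper: return bool(arr), or NumPy's error message."""
    try:
        return bool(arr)
    except ValueError as exc:
        return str(exc)

one = np.ones(1, np.bool_) > 0   # shape (1,): coercible to a single bool
two = np.ones(2, np.bool_) > 0   # shape (2,): ambiguous

print(truth(one))                # True
print(truth(two))                # the "truth value ... is ambiguous" message
print(two.any(), two.all())      # explicit reductions give one value each
```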
If I try your function with `njit` I get a long traceback, too long to show or analyze here, but it essentially says that this can't be done in `njit` mode. Given the `ValueError` above, I'm not surprised: `njit` doesn't make a special case for the truth value of a one-element array.
As a general rule, when using `numba` you should iterate. That's its main purpose: to run `numpy`/Python problems that would otherwise be too expensive to iterate over in plain Python. Don't count on `numba` to handle all the nuances of Python.
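As a sketch of that iterate-first style, here is an explicit loop that answers "is any element positive?" one scalar at a time, which is the kind of code nopython mode compiles well. The function name `any_positive` is hypothetical, and the `try`/`except` fallback (an assumption added for this example) lets the same code run as plain Python when numba isn't installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumed fallback: without numba, just run the function as plain Python.
    def njit(func):
        return func

@njit
def any_positive(a):
    # Each `x` is a scalar, so `x > 0` is an ordinary scalar boolean and
    # the `if` never needs the truth value of a whole array.
    for x in a:
        if x > 0:
            return True   # stop at the first match
    return False

print(any_positive(np.array([0, 0, 3])))   # True
print(any_positive(np.zeros(4)))           # False
```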
If I change the function to test each element of `a`, it works:
```
In [421]: @numba.njit
     ...: def foo():
     ...:     a = np.array([True])
     ...:     for i in a:
     ...:         if i > 0:
     ...:             print('truebr')
     ...:         else:
     ...:             print('falsebr')
     ...:
In [422]: foo()
truebr
```
An `all` (or `any`) wrapper also works:
```
In [423]: @numba.njit
     ...: def foo():
     ...:     a = np.array([True])
     ...:     if (a > 0).all():
     ...:         print('truebr')
     ...:     else:
     ...:         print('falsebr')
     ...:
In [424]: foo()
truebr
```
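Which wrapper to use depends on what the condition should mean once `a` has more than one element. The sketch below (hypothetical name `classify`, same optional-numba fallback as an assumption) branches on both reductions, so you can see that `.all()` and `.any()` disagree for a mixed array such as `[1, 0]`:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumed fallback: without numba, just run the function as plain Python.
    def njit(func):
        return func

@njit
def classify(a):
    mask = a > 0          # boolean array, one entry per element
    # .all() / .any() reduce the mask to a single scalar truth value,
    # which is what the `if` statements need.
    if mask.all():
        return 2          # every element passes
    if mask.any():
        return 1          # at least one element passes
    return 0              # no element passes

print(classify(np.array([1, 1])))   # 2
print(classify(np.array([1, 0])))   # 1
print(classify(np.array([0, 0])))   # 0
```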