
Python Predicates and Conditionals


A user on GitHub reported a bug in the following code, which uses numba's nopython mode (@njit):

from numba import njit
import numpy as np

@njit
def foo():
    a = np.ones(1, np.bool_)
    if a > 0:
        print('truebr')
    else:
        print('falsebr')

foo()

He was told that the expression a > 0 is not a predicate but a conditional, and that to fix it he should "Wrap conditionals in truth to create predicates".

Does this mean that writing (a > 0) == True would fix the error in numba, or does it mean something else?

Is the commit https://github.com/numba/numba/pull/3901/commits/598cdd1707fdeb11b8f1d70aef2d3e36ef37bd34 the fix for this type of error in numba?
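To see what (a > 0) == True actually produces, here is a plain NumPy check (no numba involved; the array and variable names are only illustrative). Comparing with True is still an element-wise operation, so the result is another boolean array of the same shape, and an if statement still has to ask that array for a single truth value. Reducing it with all() or any() is what yields one bool.

```python
import numpy as np

a = np.ones(2, np.bool_)

cond = a > 0            # element-wise comparison: a boolean array, not a bool
wrapped = cond == True  # still element-wise: another boolean array of the same shape

print(type(cond).__name__, cond.shape)        # ndarray (2,)
print(type(wrapped).__name__, wrapped.shape)  # ndarray (2,)

# An `if wrapped:` implicitly calls bool(wrapped); NumPy refuses to pick a
# single truth value for an array with more than one element.
try:
    bool(wrapped)
    ambiguous = False
except ValueError as exc:
    ambiguous = True
    print("ValueError:", exc)

# Reducing the array to one bool is what gives a usable predicate.
print(bool(cond.all()), bool(cond.any()))     # True True
```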


Solution

  • In plain Python (not numba), the function works:

    In [412]: def foo(): 
         ...:     a = np.ones(1, np.bool_) 
         ...:     if a > 0: 
         ...:         print('truebr') 
         ...:     else: 
         ...:         print('falsebr') 
         ...:                                                                                      
    In [413]: foo()                                                                                
    truebr
    

    But if a has more than one element, it fails:

    In [414]: def foo(): 
         ...:     a = np.ones(2, np.bool_) 
         ...:     if a > 0: 
         ...:         print('truebr') 
         ...:     else: 
         ...:         print('falsebr') 
         ...:                                                                                      
    In [415]: foo()                                                                                
    ...
    ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
    

    If I try your function under njit, I get a long traceback, too long to show or analyze here, but it essentially says this can't be compiled in nopython mode. Given the ValueError above, I'm not surprised: plain NumPy accepts the truth value of an array only when it has exactly one element, and njit doesn't allow even that special case.

    As a general rule, when using numba you should iterate explicitly. That's its main purpose: to run NumPy/Python code whose loops would otherwise be too slow in plain Python. Don't count on numba to handle all the nuances of Python.

    If I change the function to test each element of a, it works:

    In [421]: @numba.njit 
         ...: def foo(): 
         ...:     a = np.array([True]) 
         ...:     for i in a: 
         ...:         if i > 0: 
         ...:             print('truebr') 
         ...:         else: 
         ...:             print('falsebr') 
         ...:                                                                                      
    In [422]: foo()                                                                                
    truebr
    

    Wrapping the comparison in all() (or any()) also works:

    In [423]: @numba.njit 
         ...: def foo(): 
         ...:     a = np.array([True]) 
         ...:     if (a > 0).all(): 
         ...:         print('truebr') 
         ...:     else: 
         ...:         print('falsebr') 
         ...:                                                                                      
    In [424]: foo()                                                                                
    truebr
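The two working versions above can be packaged as small helpers that reduce an array to a single bool with an explicit loop, which is the style numba compiles well. This is a minimal sketch, not part of numba's API: all_true and any_true are hypothetical names, and the try/except around the import is an assumption added so the code also runs as ordinary Python where numba isn't installed. Like NumPy's all() and any(), the loops return True and False respectively for an empty array.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba, fall back to plain Python so the sketch still runs.
    def njit(func):
        return func


@njit
def all_true(arr):
    # Explicit loop over the elements; stops at the first False.
    for x in arr:
        if not x:
            return False
    return True


@njit
def any_true(arr):
    # Explicit loop over the elements; stops at the first True.
    for x in arr:
        if x:
            return True
    return False


@njit
def foo(a):
    # Reduce the element-wise comparison to one bool before branching.
    if all_true(a > 0):
        print('truebr')
    else:
        print('falsebr')


foo(np.ones(1, np.bool_))      # prints truebr
foo(np.array([True, False]))   # prints falsebr
```

In practice (a > 0).all() or (a > 0).any(), as in the answer, is shorter; the loop version mainly makes the early exit explicit.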