Why does this code fail to compile with Numba?


I have some sample code that illustrates my issue. If you run:

import numpy as np
from numba import jit


@jit(nopython=True)
def test():
    arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])

    arr2 = arr[:, 0, :]

    arr3 = arr2.argsort()

    print(arr3)

test()

It will fail with:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of BoundFunction(array.argsort for array(int64, 2d, A)) with parameters ()
During: resolving callee type: BoundFunction(array.argsort for array(int64, 2d, A))
During: typing of call at /home/stark/Work/mmr6/test.py (41)


File "test.py", line 41:
def test():
    <source elided>

    arr3 = arr2.argsort()
    ^

argsort is supposed to sort along the last axis by default, so it should give me:

>>>
[[0 1 2]
 [0 1 2]]

I thought copying the arr2 array (with copy()) might solve the problem, since it would make the array contiguous in memory (instead of a view). It still fails with the same message, except that the type of arr2 in the message is now array(int64, 2d, C), as expected.

Why is it failing and how can I fix it?


Solution

  • This is sadly a known current limitation of Numba. See this issue. In nopython mode, argsort is only supported on 1D arrays so far. However, there is a simple workaround in your case: call argsort on each row in a loop.

    import numpy as np
    from numba import jit
    
    
    @jit(nopython=True)
    def test():
        arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])
    
        arr2 = arr[:, 0, :]
    
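        # argsort returns integer indices, so use an integer dtype for the
        # output rather than the dtype of the input data.
        arr3 = np.empty(arr2.shape, dtype=np.int64)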
        for i in range(arr2.shape[0]):
            arr3[i] = arr2[i, :].argsort()
    
        print(arr3)
    
    test()
    

    Note that even if argsort were implemented for 2D arrays in Numba, it would not be faster. See this issue. In general, there is no reason for Numba to be faster than NumPy on a primitive that NumPy already provides. However, you can write your own version of a NumPy primitive with Numba and sometimes get a speed-up thanks to algorithmic specialization, parallelism, or math optimizations (e.g. fast-math). Numba is often great when you want to perform an efficient operation that is not yet, or not directly, available in NumPy and that can be trivially implemented using loops.

    Actually, you can use Numba's prange together with the JIT parameter parallel=True to speed up the computation a bit, assuming argsort is not already running in parallel (AFAIK it is sequential). On big arrays, this should be a bit faster than the NumPy implementation (which should also run sequentially); on small arrays, the cost of spawning multiple threads can exceed the cost of the actual computation.