Tags: python, arrays, numpy, sorting, numba

Sorting a NumPy array in Numba nopython mode


The Numba documentation suggests the following code should compile:

import numpy as np
from numba import njit

@njit()
def accuracy(x, y, z):
    x.argsort(axis=1)
    # compare accuracy; this code works without the above line
    accuracy_y = int(((np.equal(y, x).mean()) * 100) % 100)
    accuracy_z = int(((np.equal(z, x).mean()) * 100) % 100)
    return accuracy_y, accuracy_z

It fails on x.argsort(). I have also tried the following, with and without axis arguments:

np.argsort(x)
np.sort(x)
x.sort()

However, I get the following compilation error (or something similar):

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sort at 0x000001B3B2CD2EE0>) found for signature:
 
 >>> sort(array(int64, 2d, C))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'sort': File: numba\core\typing\npydecl.py: Line 665.
    With argument(s): '(array(int64, 2d, C))':
   No match.

During: resolving callee type: Function(<function sort at 0x000001B3B2CD2EE0>)



File "accuracy.py", line 148:
def accuracy(x,lm,sfn):
    <source elided>
    # accuracy
    np.sort(x)
    ^

What am I missing here?


Solution

  • You could also consider using guvectorize if it fits your use case. It has the benefit of letting you specify the axis to sort over. Sorting over more than one dimension can be done with repeated calls, each over a different axis.

    import numpy as np
    from numba import guvectorize

    @guvectorize("(n)->(n)")
    def sort_array(x, out):
        out[:] = np.sort(x)
    

    Using a slightly different example array whose columns are also unordered:

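    arr = np.array([
        [6, 5, 4],
        [3, 2, 1],
        [9, 8, 7],
    ])

    # Allocate the output array; it is passed to the gufunc explicitly.
    out = np.empty_like(arr)
    sort_array(arr, out, axis=0)  # sort along axis 0 (each column)
    sort_array(out, out, axis=1)  # then sort along axis 1 (each row)
    out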
    

    Shows:

    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]])
    

    A simple wrapper would allow sorting over an arbitrary number of dimensions at once. I don't think Numba's overload mechanism supports guvectorize; otherwise, you could even use it to make np.sort work inside your jitted functions without changing anything else.

    https://numba.pydata.org/numba-doc/dev/extending/overloading-guide.html

    Testing the output against NumPy:

    for _ in range(20):
        
        arr = np.random.randint(0, 99, (9,9))
    
        # numba sorting
        out_nb = np.empty_like(arr)
        sort_array(arr, out_nb, axis=0)
        sort_array(out_nb, out_nb, axis=1)
    
        # numpy sorting
        out_np = np.sort(arr, axis=0)
        out_np = np.sort(out_np, axis=1)
    
        np.testing.assert_array_equal(out_nb, out_np)