The Numba documentation suggests the following code should compile:
import numpy as np
from numba import njit

@njit()
def accuracy(x, y, z):
    x.argsort(axis=1)
    # compare accuracy (this code works without the line above)
    accuracy_y = int(((np.equal(y, x).mean()) * 100) % 100)
    accuracy_z = int(((np.equal(z, x).mean()) * 100) % 100)
    return accuracy_y, accuracy_z
It fails on x.argsort(). I have also tried the following, with and without axis arguments:
np.argsort(x)
np.sort(x)
x.sort()
However, I get the following compilation error (or similar):
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function sort at 0x000001B3B2CD2EE0>) found for signature:
>>> sort(array(int64, 2d, C))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload of function 'sort': File: numba\core\typing\npydecl.py: Line 665.
With argument(s): '(array(int64, 2d, C))':
No match.
During: resolving callee type: Function(<function sort at 0x000001B3B2CD2EE0>)
File "accuracy.py", line 148:
def accuracy(x,lm,sfn):
<source elided>
# accuracy
np.sort(x)
^
What am I missing here?
You could also consider using guvectorize if it fits your use case. Its benefit is that you can specify the axis to sort over. Sorting over more than one dimension can be done by repeated calls, each over a different axis.
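import numpy as np
from numba import guvectorize

@guvectorize("(n)->(n)")  # layout only: a dynamic gufunc, so the output array must be passed in
def sort_array(x, out):
    out[:] = np.sort(x)  # each call sorts one 1-D slice along the chosen axis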
Using a slightly different example array, whose columns are also unordered:
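arr = np.array([
    [6, 5, 4],
    [3, 2, 1],
    [9, 8, 7]])
out = np.empty_like(arr)  # preallocated output for the gufunc
sort_array(arr, out, axis=0)  # sort each column
sort_array(out, out, axis=1)  # then sort each row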
After both calls, out contains:
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])
Making a simple wrapper would allow sorting over an arbitrary number of dimensions at once. I don't think Numba's overload supports guvectorize; otherwise, you could even use it to make np.sort work inside your jitted functions without changing anything.
https://numba.pydata.org/numba-doc/dev/extending/overloading-guide.html
Testing the output against NumPy:
for _ in range(20):
    arr = np.random.randint(0, 99, (9, 9))

    # Numba sorting
    out_nb = np.empty_like(arr)
    sort_array(arr, out_nb, axis=0)
    sort_array(out_nb, out_nb, axis=1)

    # NumPy sorting
    out_np = np.sort(arr, axis=0)
    out_np = np.sort(out_np, axis=1)

    np.testing.assert_array_equal(out_nb, out_np)