I have a problem with Numba typing. I read the manual, but eventually hit a brick wall.
The function in question is part of a bigger project and needs to run fast. Python lists are out of the question, so I decided to try Numba. Sadly, the function fails in nopython=True mode, even though, as far as I understand, all types are provided.
The code is as follows:
import numpy as np
from numba import jit, njit, uint8, int64, typeof

@jit(uint8[:,:,:](int64))
def findWhite(cropped):
    h1 = int64(0)
    for i in cropped:
        for j in i:
            if np.sum(j) == 765:
                h1 = h1 + int64(1)
            else:
                pass
    return h1
Also, separately:
print(typeof(cropped))
array(uint8, 3d, C)
print(typeof(h1))
int64
In this case 'cropped' is a large 3D, C-contiguous uint8 array (an RGB TIFF image loaded with PIL.Image). Could someone please explain to a Numba newbie what I am doing wrong?
Have you considered using NumPy? It is often a good intermediate step between Python lists and Numba, something like:
h1 = (cropped.sum(axis=-1) == 765).sum()
or
h1 = (cropped == 255).all(axis=-1).sum()
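As a quick sanity check, here is a minimal sketch on a tiny hand-made array (the pixel values are made up for illustration, not data from the question) showing that both expressions count the same pixels:

```python
import numpy as np

# Hypothetical 2x2 RGB image: two pure-white pixels, two non-white ones.
cropped = np.array(
    [[[255, 255, 255], [0, 0, 0]],
     [[255, 255, 254], [255, 255, 255]]],
    dtype=np.uint8,
)

# NumPy upcasts uint8 when summing, so 255 * 3 = 765 does not overflow.
count_by_sum = int((cropped.sum(axis=-1) == 765).sum())

# Equivalent for uint8 data: every band of the pixel equals 255.
count_by_all = int((cropped == 255).all(axis=-1).sum())

print(count_by_sum, count_by_all)
```

Both counts agree here because a uint8 pixel can only sum to 765 when all three bands are 255.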
The example code you provide is not valid Numba. Your signature is also incorrect: a Numba signature is written as return_type(argument_types), and since the input is a 3D array and the output an integer, it should probably be:
@njit(int64(uint8[:,:,:]))
Looping over the array like you do is not valid code. A close translation of your code would be something like this:
@njit(int64(uint8[:,:,:]))
def findWhite(cropped):
    h1 = int64(0)
    ys, xs, n_bands = cropped.shape
    for i in range(ys):
        for j in range(xs):
            if cropped[i, j, :].sum() == 765:
                h1 += 1
    return h1
But that isn't very fast and doesn't beat NumPy on my machine. With Numba it's fine to loop explicitly over every element of an array; this is already a lot faster:
@njit(int64(uint8[:,:,:]))
def findWhite_numba(cropped):
    h1 = int64(0)
    ys, xs, zs = cropped.shape
    for i in range(ys):
        for j in range(xs):
            incr = 1
            for k in range(zs):
                if cropped[i, j, k] != 255:
                    incr = 0
                    break
            h1 += incr
    return h1
For a 5000x5000x3 array, these are the results for me:
NumPy (h1 = (cropped == 255).all(axis=-1).sum()):
427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
findWhite:
612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
findWhite_numba:
31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
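To get comparable numbers on your own machine, something along these lines should work. It is only a sketch using the standard library's timeit on the NumPy variant; the random test array, its smaller 500x500x3 size, and the helper name numpy_count are assumptions for illustration, not the setup used for the timings above:

```python
import timeit
import numpy as np

# Hypothetical test data: random uint8 "image", smaller than 5000x5000x3
# so the sketch runs quickly.
rng = np.random.default_rng(0)
cropped = rng.integers(0, 256, size=(500, 500, 3), dtype=np.uint8)

def numpy_count(a):
    # Count pixels whose bands are all 255.
    return int((a == 255).all(axis=-1).sum())

# Run 5 batches of 10 calls each and report the fastest batch per call.
runs = timeit.repeat(lambda: numpy_count(cropped), number=10, repeat=5)
best_ms = min(runs) / 10 * 1000
print(f"best of 5: {best_ms:.3f} ms per call")
```

The same pattern can be pointed at either Numba function by swapping the callable passed to timeit.repeat.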
A benefit of the NumPy method is that it generalizes to any number of dimensions.
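To illustrate that last point, here is a small sketch in which the same expression is applied to inputs of different dimensionality. The helper name count_white and the array shapes are made up for the example; the only assumption about the data is that the colour bands sit on the last axis:

```python
import numpy as np

def count_white(pixels):
    """Count entries whose last axis (the colour bands) is all 255."""
    return int((pixels == 255).all(axis=-1).sum())

# Hypothetical stack of 4 frames, each 3x5 pixels with 3 bands, all black.
frames = np.zeros((4, 3, 5, 3), dtype=np.uint8)
frames[0, 1, 2] = 255   # one white pixel in frame 0
frames[3, 0, 0] = 255   # one white pixel in frame 3

print(count_white(frames))        # 4D: whole stack
print(count_white(frames[0]))     # 3D: a single image
print(count_white(frames[0, 1]))  # 2D: a single row of pixels
```

The Numba versions above, by contrast, are compiled for the uint8[:,:,:] signature and unpack exactly three dimensions, so they would need to be rewritten for other shapes.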