
Numba complains about typing - but all types are being provided


I have a problem with Numba typing: I read the manual, but eventually hit a brick wall.

The function in question is part of a bigger project and needs to run fast, so Python lists are out of the question; hence I decided to try Numba. Sadly, the function fails in nopython=True mode even though, as far as I understand, all types are provided.

The code is as follows:

import numpy as np
from numba import jit, njit, uint8, int64, typeof

@jit(uint8[:,:,:](int64))
def findWhite(cropped):
    h1 = int64(0)
    for i in cropped:
        for j in i:
            if np.sum(j) == 765:
                h1 = h1 + int64(1)
            else:
                pass
    return h1

Also, checked separately:

print(typeof(cropped))
# array(uint8, 3d, C)
print(typeof(h1))
# int64

In this case, 'cropped' is a large uint8 3D C-contiguous array (an RGB TIFF file loaded with PIL.Image). Could someone please explain to a Numba newbie what I am doing wrong?


Solution

  • Have you considered using Numpy? That's often a good intermediate between Python lists and Numba, something like:

    h1 = (cropped.sum(axis=-1) == 765).sum()
    

    or

    h1 = (cropped == 255).all(axis=-1).sum()
    

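    The example code you provide won't compile as written. The main problem is your signature: a Numba signature is written as return_type(argument_types), so yours declares an int64 input and a 3D uint8 array output. Since the input is a 3D array and the output an integer, it should probably be: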

    @njit(int64(uint8[:,:,:]))
    

    Your loops iterate over the array itself; a close translation of your code using explicit indices would be something like this:

    @njit(int64(uint8[:,:,:]))
    def findWhite(cropped):
    
        h1 = int64(0)    
        ys, xs, n_bands = cropped.shape
    
        for i in range(ys):
            for j in range(xs):
                if cropped[i, j, :].sum() == 765:
                    h1 += 1
    
        return h1
    

    But that isn't very fast and doesn't beat Numpy on my machine. With Numba it's fine to loop explicitly over every element of the array, and doing so is already a lot faster:

    @njit(int64(uint8[:,:,:]))
    def findWhite_numba(cropped):
    
        h1 = int64(0)    
        ys, xs, zs = cropped.shape
    
        for i in range(ys):
            for j in range(xs):
    
                incr = 1
                for k in range(zs):
    
                    if cropped[i, j, k] != 255:
                        incr = 0
                        break
    
                h1 += incr
    
        return h1
    

    For a 5000x5000x3 array, these are the results I get:

    Numpy (h1 = (cropped == 255).all(axis=-1).sum()):

    427 ms ± 6.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    findWhite:

    612 ms ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    findWhite_numba:

    31 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    A benefit of the Numpy method is that it generalizes to any number of dimensions.
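The two Numpy one-liners suggested at the start of this answer can be checked on a tiny hand-made image. This sketch assumes an RGB array of shape (height, width, 3); count_white_sum and count_white_all are illustrative helper names. It also shows why the sum-based version is safe with uint8 data: ndarray.sum accumulates small integer types in a wider integer, whereas adding the channels with + stays in uint8 and wraps around.

```python
import numpy as np

def count_white_sum(cropped):
    # Sum the channels of each pixel; sum() widens uint8, so 3 * 255 = 765 fits
    return int((cropped.sum(axis=-1) == 765).sum())

def count_white_all(cropped):
    # A pixel is white when every channel equals 255
    return int((cropped == 255).all(axis=-1).sum())

# Tiny 2 x 2 RGB image: two white pixels, one black, one almost white
img = np.array(
    [[[255, 255, 255], [0, 0, 0]],
     [[255, 255, 255], [254, 255, 255]]],
    dtype=np.uint8,
)
print(count_white_sum(img), count_white_all(img))  # 2 2

# Adding channels with + keeps the uint8 dtype and wraps modulo 256
wrapped = img[..., 0] + img[..., 1] + img[..., 2]
print(wrapped[0, 0])  # 253
```

Note that sum == 765 only means "all channels are 255" when there are exactly three channels; with an RGBA image it would also match pixels such as (255, 255, 255, 0). The all() version does not depend on the channel count.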
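The generality of the Numpy method can be illustrated with a few arrays of different shapes: the same (pixels == 255).all(axis=-1).sum() expression works on a strip of pixels, a single image, a stack of images, or an image with four channels, whereas the explicit signature of findWhite_numba only accepts 3D uint8 arrays. A small sketch, with count_white as an illustrative name and made-up arrays:

```python
import numpy as np

def count_white(pixels):
    # Require all channels (last axis) to be 255, then count matching
    # pixels over every remaining axis
    return int((pixels == 255).all(axis=-1).sum())

row = np.full((10, 3), 255, dtype=np.uint8)      # strip of 10 RGB pixels
image = np.zeros((4, 4, 3), dtype=np.uint8)      # single 4 x 4 RGB image
image[0, :] = 255                                # make the first row white
stack = np.stack([image, image, image])          # 3 images, 4D array
rgba = np.full((2, 2, 4), 255, dtype=np.uint8)   # RGBA image, 4 channels

print(count_white(row), count_white(image), count_white(stack), count_white(rgba))
# 10 4 12 4
```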