
How to speed up a python function with numba


I'm trying to speed up my implementation of the Floyd-Steinberg dithering algorithm with Numba. After going through the beginner's guide, I added the @jit decorator to my code:

import numpy as np
from numba import jit

def findClosestColour(pixel):
    colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]])
    distances = np.sum(np.abs(pixel[:, np.newaxis].T - colors), axis=1)
    shortest = np.argmin(distances)
    closest_color = colors[shortest]
    return closest_color

@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def floydDither(img_array):
    height, width, _ = img_array.shape
    for y in range(0, height-1):
        for x in range(1, width-1):
            old_pixel = img_array[y, x, :]
            new_pixel = findClosestColour(old_pixel)
            img_array[y, x, :] = new_pixel
            quant_error = new_pixel - old_pixel
            img_array[y, x+1, :] =  img_array[y, x+1, :] + quant_error * 7/16
            img_array[y+1, x-1, :] =  img_array[y+1, x-1, :] + quant_error * 3/16
            img_array[y+1, x, :] =  img_array[y+1, x, :] + quant_error * 5/16
            img_array[y+1, x+1, :] =  img_array[y+1, x+1, :] + quant_error * 1/16
    return img_array

However, I get the following error:

Untyped global name 'findClosestColour': Cannot determine Numba type of <class 'function'>

I think I understand that Numba doesn't know the type of findClosestColour, but I just got started with Numba and don't know how to handle the error.

Here's the code I used to test the function:

import cv2

image = cv2.imread('logo.jpeg')
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
im_out = floydDither(img)

And here's the test image I used.


Solution

  • First of all, it is not possible to call pure-Python functions from Numba nopython-jitted functions (aka njit functions), because Numba needs to know all types at compile time to generate an efficient binary. The fix is to decorate findClosestColour with @njit as well, so that floydDither calls a compiled function.

    Moreover, Numba cannot compile the expression pixel[:, np.newaxis].T, because np.newaxis appears not to be supported yet (probably because np.newaxis is None). You can use pixel.reshape(3, -1).T instead.

    Note that you should be careful about the types, because computing a - b when both variables are of type np.uint8 can overflow (e.g. 0 - 1 == 255, or even more surprising: 0 - 256 == 65280 when b is a literal integer and a is of type np.uint8). Note that the array is computed in-place and that pixels are written before


    The generated code will not be very efficient, although Numba does a good job. You can iterate over the colors yourself with a loop to find the index of the minimum. This is a bit faster because it does not create many small temporary arrays. You can also give the functions explicit type signatures so that Numba compiles them eagerly, when they are defined, rather than lazily on the first call. That being said, this also makes the code lower-level and therefore more verbose and harder to maintain.

    Here is an optimized implementation:

    import numba as nb
    import numpy as np

    @nb.njit('int32[::1](uint8[::1])')
    def nb_findClosestColour(pixel):
        colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255], [255, 255, 0], [0, 128, 0], [253, 134, 18]], dtype=np.int32)
        r,g,b = pixel.astype(np.int32)
        r2,g2,b2 = colors[0]
        minDistance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
        shortest = 0
        for i in range(1, colors.shape[0]):
            r2,g2,b2 = colors[i]
            distance = np.abs(r-r2) + np.abs(g-g2) + np.abs(b-b2)
            if distance < minDistance:
                minDistance = distance
                shortest = i
        return colors[shortest]
    
    @nb.njit('uint8[:,:,::1](uint8[:,:,::1])')
    def nb_floydDither(img_array):
        assert img_array.shape[2] == 3  # expects 3-channel (RGB) images
        height, width, _ = img_array.shape
        for y in range(0, height-1):
            for x in range(1, width-1):
                old_pixel = img_array[y, x, :]
                new_pixel = nb_findClosestColour(old_pixel)
                img_array[y, x, :] = new_pixel
                quant_error = new_pixel - old_pixel
                img_array[y, x+1, :] =  img_array[y, x+1, :] + quant_error * 7/16
                img_array[y+1, x-1, :] =  img_array[y+1, x-1, :] + quant_error * 3/16
                img_array[y+1, x, :] =  img_array[y+1, x, :] + quant_error * 5/16
                img_array[y+1, x+1, :] =  img_array[y+1, x+1, :] + quant_error * 1/16
        return img_array
    

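    Compared with the initial pure-Python code, the naive Numba version (the original code with findClosestColour also compiled with @njit) is 14 times faster, while the optimized version above is 19 times faster.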
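The first two points of the answer can be sketched as follows. Decorating findClosestColour with @njit lets another nopython function call it. Instead of pixel[:, np.newaxis].T, the 1-D pixel is broadcast directly against the (6, 3) palette, which yields the same differences as pixel.reshape(3, -1).T. The quantize function is hypothetical and only exists to show a jitted caller. The try/except fallback is also an assumption, so that the snippet still runs as plain (slow) Python when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if Numba is missing
    def njit(func):
        return func


@njit
def findClosestColour(pixel):
    colors = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255],
                       [255, 255, 0], [0, 128, 0], [253, 134, 18]])
    # (3,) - (6, 3) broadcasts to (6, 3), like pixel.reshape(3, -1).T - colors;
    # mixing uint8 with the integer palette gives signed differences (no wraparound)
    distances = np.sum(np.abs(pixel - colors), axis=1)
    return colors[np.argmin(distances)]


@njit
def quantize(pixels):
    # Hypothetical caller: nearest palette colour for each row of an (N, 3) array.
    # This call compiles only because findClosestColour is jitted as well.
    out = np.empty((pixels.shape[0], 3), dtype=np.int64)
    for i in range(pixels.shape[0]):
        out[i] = findClosestColour(pixels[i])
    return out
```

For example, quantize(np.array([[250, 10, 10]], dtype=np.uint8)) maps that reddish pixel to [255, 0, 0].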
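The remarks about uint8 overflow and in-place computation can be checked with plain NumPy. The array names below are illustrative only. The sketch shows uint8 subtraction wrapping around, and the usual workaround of casting to a wider signed type before subtracting. It also shows that img_array[y, x, :] is a view. In both floydDither and nb_floydDither the pixel is overwritten before quant_error is computed, so the "old" value read through the view already equals the new one and quant_error is zero. Taking a .copy() of the old pixel avoids that.

```python
import numpy as np

# uint8 arithmetic wraps around modulo 256 instead of going negative
a = np.array([0], dtype=np.uint8)
b = np.array([1], dtype=np.uint8)
wrapped = a - b                    # [255]
signed = a.astype(np.int32) - b    # [-1]: cast to a signed type before subtracting

# Basic slicing returns a view, which tracks later writes to the same pixel
img = np.zeros((2, 2, 3), dtype=np.uint8)
old_view = img[0, 0, :]            # view into img
old_copy = img[0, 0, :].copy()     # independent snapshot
img[0, 0, :] = [255, 0, 0]         # "quantize" the pixel in place
# old_view now reads [255, 0, 0]; old_copy still holds [0, 0, 0]
```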
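Here is a sketch of a dithering loop that combines the answer's remarks. It is not the benchmarked code above, and its speed has not been measured. Assumptions: the input is an RGB uint8 image, the question's palette is passed in as a float64 array, and errors accumulate in a float64 copy of the image that is clamped to [0, 255] before quantizing (a common choice, not the only one). The sketch also makes four other choices:

- It snapshots the current pixel instead of keeping a view.
- It uses the standard Floyd-Steinberg sign, quant_error = old - new; the code above computes new_pixel - old_pixel.
- It loops over the palette explicitly, as suggested.
- It also processes the border pixels that the range(1, width-1) and range(0, height-1) loops skip.

The names floyd_dither, closest_index and PALETTE are hypothetical. The njit fallback is the same assumption as before.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain (slow) Python without Numba
    def njit(func):
        return func

# Palette from the question (RGB), as floats so it mixes with the error buffer
PALETTE = np.array([[255, 255, 255], [255, 0, 0], [0, 0, 255],
                    [255, 255, 0], [0, 128, 0], [253, 134, 18]], dtype=np.float64)


@njit
def closest_index(pixel, palette):
    # Explicit loop over the palette: Manhattan distance, no temporary arrays
    best, best_dist = 0, np.inf
    for i in range(palette.shape[0]):
        d = 0.0
        for c in range(pixel.shape[0]):
            d += abs(pixel[c] - palette[i, c])
        if d < best_dist:
            best, best_dist = i, d
    return best


@njit
def floyd_dither(img, palette):
    height, width, channels = img.shape
    buf = img.astype(np.float64)   # accumulated values may leave [0, 255]
    out = np.empty_like(img)
    old = np.empty(channels)       # scratch copy of the current pixel (not a view)
    for y in range(height):
        for x in range(width):
            for c in range(channels):
                old[c] = min(max(buf[y, x, c], 0.0), 255.0)
            k = closest_index(old, palette)
            for c in range(channels):
                out[y, x, c] = np.uint8(palette[k, c])
                err = old[c] - palette[k, c]   # quant_error = old - new
                # Floyd-Steinberg weights: 7/16 right, 3/16 below-left,
                # 5/16 below, 1/16 below-right (skipped outside the image)
                if x + 1 < width:
                    buf[y, x + 1, c] += err * 7 / 16
                if y + 1 < height:
                    if x > 0:
                        buf[y + 1, x - 1, c] += err * 3 / 16
                    buf[y + 1, x, c] += err * 5 / 16
                    if x + 1 < width:
                        buf[y + 1, x + 1, c] += err * 1 / 16
    return out
```

With the question's test code, this would be called as im_out = floyd_dither(img, PALETTE). Because the error lives in a separate float buffer, the input image is left untouched and no uint8 wraparound can occur.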