python design-patterns refactoring numba

numba @jit(nopython=True): Refactor function that uses dictionary with lists as values


I have several functions that I want to use numba @jit(nopython=True) for, but they all rely on the following function:

def getIslands(labels2D,ignoreSea=True):
    islands = {}
    width = labels2D.shape[1]
    height = labels2D.shape[0]
    for x in range(width):
        for y in range(height):
            label = labels2D[y,x]
            if ignoreSea and label == -1:
                continue
            if label in islands:
                islands[label].append((x,y))
            else:
                islands[label] = [(x,y)]
    return islands

Is there any way to redesign this function so that it is compatible with numba @jit(nopython=True)? The JIT compilation fails because the function uses features of Python dictionaries that are not supported (namely, dictionaries with lists as values).

numba==0.52.0
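
For reference, here is a small sketch of what the pure-Python function returns. The 2x3 label grid is made up for illustration; the function body is the one from the question, unchanged.

```python
import numpy as np

def getIslands(labels2D, ignoreSea=True):
    islands = {}
    width = labels2D.shape[1]
    height = labels2D.shape[0]
    for x in range(width):
        for y in range(height):
            label = labels2D[y, x]
            if ignoreSea and label == -1:
                continue
            if label in islands:
                islands[label].append((x, y))
            else:
                islands[label] = [(x, y)]
    return islands

# Hypothetical sample grid: 0 and 1 are island labels, -1 marks the sea.
labels = np.array([[0, -1, 1],
                   [0,  1, 1]])

islands = getIslands(labels)
# Convert the NumPy integer keys to plain ints so the result prints cleanly.
plain = {int(k): v for k, v in islands.items()}
print(plain)  # {0: [(0, 0), (0, 1)], 1: [(1, 1), (2, 0), (2, 1)]}
```

Each value is a list of (x, y) tuples keyed by label, which is the dictionary-of-lists shape that nopython mode cannot type on its own.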


Solution

  • Dictionaries and lists are not very user-friendly in Numba yet. You first need to declare the type of the dictionary values (outside the function):

    import numba as nb
    import numpy as np  # needed for the np.int_ casts used in the function
    
    intTupleList = nb.types.List(nb.types.UniTuple(nb.int_, 2))
    

    Then you can create an empty typed dictionary inside the function with nb.typed.typeddict.Dict.empty. The same applies to lists, using nb.typed.typedlist.List. Here is how:

    @nb.njit('(int_[:,:], bool_)')
    def getIslands(labels2D,ignoreSea=True):
        islands = nb.typed.typeddict.Dict.empty(key_type=nb.int_, value_type=intTupleList)
        width = labels2D.shape[1]
        height = labels2D.shape[0]
        for x in range(width):
            for y in range(height):
                label = labels2D[y,x]
                if ignoreSea and label == -1:
                    continue
                if label in islands:
                    islands[label].append((np.int_(x),np.int_(y)))
                else:
                    islands[label] = nb.typed.typedlist.List([(np.int_(x),np.int_(y))])
        return islands
    

    It is a bit unfortunate that Numba cannot yet infer the type of the list [(x, y)], since nb.typed.typedlist.List is somewhat painful to use, especially with the additional casts that are required because of the mismatch between the nb.int_ type and loop iterators, which are nb.int64 on 64-bit machines.