Numba: converting numpy array into hashable object


Inside a Numba @nb.jit(nopython=True) function I am calculating thousands of NumPy arrays (1-D, integer dtype) and appending them to a list. The problem is that some of the arrays turn out to be equal, and I don't need duplicates. So I need an efficient way to check whether a new array already exists in the list.

In plain Python it can be done like this:

import numpy as np
import numba as nb

# @nb.jit(nopython=True)
def foo(n):

    uniques = []
    uniques_set = set()

    for _ in range(n):

        arr = np.random.randint(0, 2, 2)
        arr_hashable = make_hashable(arr)

        if arr_hashable not in uniques_set:
            uniques_set.add(arr_hashable)
            uniques.append(arr)

    return uniques

I've tried two ways to solve this:

  1. Converting the array to a tuple and putting the tuple into a set.

    def make_hashable(arr):
        return tuple(arr)
    

    Unfortunately, direct tuple construction doesn't work this way in nopython mode. I've also tried it this way:

    def make_hashable(arr):
        res = ()
        for n in arr:
            res += (n,)
        return res
    

    along with other similar workarounds I could think of, but all of them failed in nopython mode with a TypeError.

  2. Converting the array to a string and putting that into the set.

    def make_hashable(arr):
        return arr.tostring()
    

    I also tried every way I could find to convert the array to a string, but it seems Numba doesn't support string conversion for now.

Maybe there are other approaches to check efficiently whether an array already exists in a list? My Numba version is 0.44. Thanks a lot.


Solution

  • I have Numba 0.58, but the only way I know around your problem is still to call back into object mode to hash the array, like this:

    import numpy as np
    import numba as nb
    
    def make_hashable(arr):
        return hash(arr.tobytes())
    
    @nb.jit(nopython=True)
    def foo(n):
        uniques = []
        uniques_set = set()
        for _ in range(n):
            arr = np.random.randint(0, 2, 2)
            with nb.objmode(arr_hashable='intp'):
                arr_hashable = make_hashable(arr)
    
            if arr_hashable not in uniques_set:
                uniques_set.add(arr_hashable)
                uniques.append(arr)
        return uniques
    
    foo(100)
    # => [array([0, 0]), array([0, 1]), array([1, 1]), array([1, 0])]
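
A note on what the object-mode helper actually hashes: tobytes() returns the raw buffer, so the result depends on the dtype as well as on the values. Here is a minimal sketch in plain NumPy (no Numba needed); the int64 and int32 arrays are only assumptions chosen for illustration:

```python
import numpy as np

def make_hashable(arr):
    # Same helper as above: hash the raw bytes of the array.
    return hash(arr.tobytes())

a = np.array([0, 1], dtype=np.int64)
b = np.array([0, 1], dtype=np.int64)
c = np.array([0, 1], dtype=np.int32)  # same values, different dtype

# Equal arrays of the same dtype always produce the same hash.
same_hash = make_hashable(a) == make_hashable(b)

# Same values with a different dtype give a different byte string
# (16 bytes vs 8 bytes here), so they would not be recognised as duplicates.
bytes_differ = a.tobytes() != c.tobytes()
```

So this only deduplicates reliably if all arrays share one dtype, which is the case when they all come from the same call inside the jitted function.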
    

    EDIT:

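    The version above stores only hash values, so two different arrays whose hashes collide would wrongly be treated as duplicates. If you want to handle collisions properly, you can replace the set with a dict that maps each hash value to a list of the previously seen arrays with that hash. The code is a bit more complicated, of course: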

    import numpy as np
    import numba as nb
    
    def make_hashable(arr):
        return hash(arr.tobytes())
    
    list_type = nb.types.ListType(nb.int64[:])
    
    @nb.jit(nopython=True)
    def foo(n):
        uniques = []
        uniques_dict = nb.typed.Dict.empty(nb.int64, list_type)
        for _ in range(n):
            arr = np.random.randint(0, 2, 2)
            with nb.objmode(arr_hashable='intp'):
                arr_hashable = make_hashable(arr)
    
            is_seen = False
            seen_arrs = uniques_dict.get(arr_hashable)
            if seen_arrs is not None:
                for seen_arr in seen_arrs:
                    if np.array_equal(arr, seen_arr):
                        is_seen = True
                        break
        
            if not is_seen:
                if seen_arrs is None:
                    seen_arrs = nb.typed.List.empty_list(nb.int64[:])
                    uniques_dict[arr_hashable] = seen_arrs
                seen_arrs.append(arr)
                uniques.append(arr)
        return uniques
    
    foo(100)
    # => [array([1, 0]), array([0, 0]), array([0, 1]), array([1, 1])]
    

    I had to define list_type outside the jitted function; otherwise compilation failed. You can test the code by using a bad hash function, for example one that always returns 1.
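
To illustrate that test without involving Numba, here is a rough plain-Python sketch of the same bucket logic. The names bad_hash and dedupe are hypothetical helpers made up for this example, not part of the code above:

```python
import numpy as np

def bad_hash(arr):
    # Deliberately terrible hash: every array lands in the same bucket.
    return 1

def dedupe(arrays, hash_fn):
    uniques = []
    buckets = {}  # hash value -> list of arrays already seen with that hash
    for arr in arrays:
        seen_arrs = buckets.setdefault(hash_fn(arr), [])
        # Compare element-wise only against arrays sharing the same hash.
        if not any(np.array_equal(arr, seen) for seen in seen_arrs):
            seen_arrs.append(arr)
            uniques.append(arr)
    return uniques

pairs = [(0, 0), (0, 1), (0, 0), (1, 1), (0, 1), (1, 0)]
arrays = [np.array(p) for p in pairs]

result = dedupe(arrays, bad_hash)
print([a.tolist() for a in result])
# [[0, 0], [0, 1], [1, 1], [1, 0]]
```

Even though every array collides, the duplicates are still removed correctly, because the final decision is made by np.array_equal and not by the hash. With a constant hash the lookup degrades to a linear scan, so a real hash is still needed for speed.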

    Hopefully Numba will support bytes in the future and all of this will become unnecessary. There is already a ticket for this: https://github.com/numba/numba/issues/5149