Tags: python, dictionary, tuples, numba, jit

How to create a dictionary with tuple keys in a numba njit function


I'm pretty inexperienced with numba (and with posting questions), so hopefully this isn't a misspecified question.

I am trying to create a jitted function that involves a dictionary. I want the dictionary to have tuples as keys, and floats as values. Below is some code adapted from the numba docs that I've used to demonstrate my question.

I understand numba wants variable types to be specified. I think the problem is that I am not specifying the right numba type for the dictionary key inside the function. I've looked at this question but still can't figure out what to do.

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# Make array type.  Type-expression is not supported in jit functions.
float_array = types.float64[:]

@njit
def foo():
    list_out=[]
    # Make dictionary
    d = Dict.empty(
        key_type=types.Tuple,  # <= I suppose I'm not putting the right type here
        value_type=float_array,
    )
    # an example of how I would like to fill the dictionary
    d[(1,1)] = np.arange(3).astype(np.float64)
    d[(2,2)] = np.arange(3, 6).astype(np.float64)
    list_out.append(d[(2,2)])
    return list_out

list_out = foo()

Any help or guidance is appreciated. Thanks for your time!


Solution

  • types.Tuple on its own is an incomplete type, so it is not valid as a key type. You need to specify the types of the items in the tuple. In this case, you can use types.UniTuple(types.int32, 2) as a complete key type (a tuple containing two 32-bit integers). Here is the resulting code:

    import numpy as np
    from numba import njit
    from numba import types
    from numba.typed import Dict
    
    # Make key type with two 32-bit integer items.
    key_type = types.UniTuple(types.int32, 2)
    
    # Make array type.  Type-expression is not supported in jit functions.
    float_array = types.float64[:]
    
    @njit
    def foo():
        list_out=[]
        # Make dictionary
        d = Dict.empty(
            key_type=key_type, 
            value_type=float_array,
        )
        # an example of how I would like to fill the dictionary
        d[(1,1)] = np.arange(3).astype(np.float64)
        d[(2,2)] = np.arange(3, 6).astype(np.float64)
        list_out.append(d[(2,2)])
        return list_out
    
    list_out = foo()
    

    By the way, np.arange accepts a dtype parameter, so you can write np.arange(3, dtype=np.float64) directly, which is more efficient than creating an integer array and converting it with astype.
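
To illustrate the point about dtype, here is a small NumPy-only check that both spellings produce the same array. The variable names are just for illustration.

```python
import numpy as np

# Two-step version: build an integer array, then copy it into a float64 array.
a = np.arange(3, 6).astype(np.float64)

# One-step version: build the float64 array directly.
b = np.arange(3, 6, dtype=np.float64)

# Both spellings give the same values and the same dtype.
same_values = np.array_equal(a, b)
same_dtype = a.dtype == b.dtype
```

The one-step version avoids allocating the intermediate integer array.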