
Is it possible to create a Numba dict whose key type is a UniTuple inside a function?


I want to instantiate a Numba typed Dict inside a jitted function, with a tuple of three floats as the key type. To do so, I wrote the following code:

import numba


@numba.njit
def foo():
    local_dict = numba.typed.Dict.empty(
        key_type=numba.types.UniTuple(numba.float64, 3),
        value_type=numba.float64,
    )
    return 1


if __name__ == '__main__':
    foo()

Unfortunately, this code fails to compile (error message below).
However, when I instantiate local_dict at the module level with exactly the same code, it compiles successfully.
I also tried changing the key type to float64 and that worked, which suggests (as does the error message) that the problem comes from the UniTuple type.
So my question is: how do I declare a dict with a UniTuple key inside a function?

Here is the full error message:

Traceback (most recent call last):
  File "/home/louis/PycharmProjects/Bac_a_sable/numba_sandbox.py", line 19, in <module>
    foo()
  File "/home/louis/.venvs/Bac_a_sable/lib/python3.9/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/louis/.venvs/Bac_a_sable/lib/python3.9/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'UniTuple' of type Module(<module 'numba.core.types' from '/home/louis/.venvs/Bac_a_sable/lib/python3.9/site-packages/numba/core/types/__init__.py'>)

File "numba_sandbox.py", line 8:
def foo():
    <source elided>
        # key_type=numba.float64, value_type=numba.float64,
        key_type=numba.types.UniTuple(dtype=numba.float64, count=3), value_type=numba.float64,
        ^

During: typing of get attribute at /home/louis/PycharmProjects/Bac_a_sable/numba_sandbox.py (8)

File "numba_sandbox.py", line 8:
def foo():
    <source elided>
        # key_type=numba.float64, value_type=numba.float64,
        key_type=numba.types.UniTuple(dtype=numba.float64, count=3), value_type=numba.float64,
        ^


Process finished with exit code 1

Solution

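  • The Numba docs state that "Type-expression is not supported in jit functions", so numba.types.UniTuple(numba.float64, 3) cannot be evaluated inside the @njit function. Build the type objects at module level instead; the jitted function then reads them as global variables, which Numba treats as compile-time constants.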

    import numba
    
    from numba.types import UniTuple
    
    # Declare the types _outside_ of the function definition
    value_float = numba.float64
    key_tuple = UniTuple(numba.float64, 3)
    
    @numba.njit
    def foo():
        # key_tuple and value_float are globals, frozen at compile time
        local_dict = numba.typed.Dict.empty(
            key_type=key_tuple,
            value_type=value_float,
        )
        return local_dict
    
    
    if __name__ == '__main__':
        print(foo())  # prints: {}