
Passing config variables to functions so they behave as compile time constants


In Numba, I want to pass the config variables to a function as compile-time constants. Specifically, what I want to do is:

    from numba import njit
    from numba.typed import List
    
    @njit
    def physics(config):
        flagA = config.flagA
        flagB = config.flagB
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

If the config variables were compile-time constants, this would compile, but they are not, so Numba raises an error saying that there are two candidate implementations:

    There are 2 candidate implementations:
          - Of which 2 did not match due to:
          ...
          ...

I looked at one of the Numba meeting minutes (Numba Meeting: 2024-03-05) and found a suggested way to do this. I tried it, but it still raises the same error.

Here is the code with the error message:

.. code:: ipython3

    from numba import jit, types, njit
    from numba.extending import overload
    from numba.typed import List
    import functools

.. code:: ipython3

    class Config():
        def __init__(self, flagA, flagB):
            self._flagA = flagA
            self._flagB = flagB
    
        @property
        def flagA(self):
            return self._flagA
    
        @property
        def flagB(self):
            return self._flagB

.. code:: ipython3

    @functools.cache
    def obj2strkeydict(obj, config_name):
    
        # unpack object to freevars and close over them
        tmp_a = obj.flagA
        tmp_b = obj.flagB
        assert isinstance(config_name, str)
        tmp_force_heterogeneous = config_name
    
        @njit
        def configurator():
            d = {'flagA': tmp_a,
                 'flagB': tmp_b,
                 'config_name': tmp_force_heterogeneous}
            return d
    
        # return a configuration function that returns a string-key-dict
        # representation of the configuration object.
        return configurator

.. code:: ipython3

    @njit
    def physics(cfig_func):
        config = cfig_func()
        flagA = config['flagA']
        flagB = config['flagB']
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

.. code:: ipython3

    def demo():
        configuration1 = Config(True, False)
        jit_config1 = obj2strkeydict(configuration1, 'config1')
        physics(jit_config1)

.. code:: ipython3

    demo()


::


    ---------------------------------------------------------------------------

    TypingError                               Traceback (most recent call last)

    Cell In[83], line 1
    ----> 1 demo()


    Cell In[82], line 4, in demo()
          2 configuration1 = Config(True, False)
          3 jit_config1 = obj2strkeydict(configuration1, 'config1')
    ----> 4 physics(jit_config1)


    File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
        464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
        465                f"by the following argument(s):\n{args_str}\n")
        466         e.patch_message(msg)
    --> 468     error_rewrite(e, 'typing')
        469 except errors.UnsupportedError as e:
        470     # Something unsupported is present in the user code, add help info
        471     error_rewrite(e, 'unsupported_error')


    File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
        407     raise e
        408 else:
    --> 409     raise e.with_traceback(None)


    TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    - Resolution failure for literal arguments:
    No implementation of function Function(<function impl_append at 0x7fd87d253920>) found for signature:
    
     >>> impl_append(ListType[int64], float64)
    
    There are 2 candidate implementations:
          - Of which 2 did not match due to:
          Overload in function 'impl_append': File: numba/typed/listobject.py: Line 592.
            With argument(s): '(ListType[int64], float64)':
           Rejected as the implementation raised a specific error:
             TypingError: Failed in nopython mode pipeline (step: nopython frontend)
           No implementation of function Function(<intrinsic _cast>) found for signature:
    
            >>> _cast(float64, class(int64))
    
           There are 2 candidate implementations:
                 - Of which 2 did not match due to:
                 Intrinsic in function '_cast': File: numba/typed/typedobjectutils.py: Line 22.
                   With argument(s): '(float64, class(int64))':
                  Rejected as the implementation raised a specific error:
                    TypingError: cannot safely cast float64 to int64. Please cast explicitly.
             raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/typedobjectutils.py:75
           
           During: resolving callee type: Function(<intrinsic _cast>)
           During: typing of call at /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py (600)
           
           
           File "../anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py", line 600:
               def impl(l, item):
                   casteditem = _cast(item, itemty)
                   ^
    
      raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/typeinfer.py:1086
    
    - Resolution failure for non-literal arguments:
    None
    
    During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[int64])
    During: typing of call at /tmp/ipykernel_9889/739598600.py (11)
    
    
    File "../../../tmp/ipykernel_9889/739598600.py", line 11:
    <source missing, REPL/exec in use?>

Any help, or a reference to related material, would be much appreciated. Thank you.
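For contrast, the last frames of the traceback show the root cause: the typed list's item type was fixed to int64, so appending a float64 is rejected ("cannot safely cast float64 to int64"). If a genuinely runtime flag is acceptable, a minimal sketch is to make both branches append the same item type, here float. This is a swapped-in workaround, not the compile-time specialization asked about, and `physics_runtime` is a hypothetical name; the import fallback is only there so the sketch also runs without Numba installed.

```python
try:
    from numba import njit
    from numba.typed import List
except ImportError:  # assumption: plain-Python fallback so the sketch runs without Numba
    def njit(func):
        return func
    List = list


@njit
def physics_runtime(flagA):
    # flagA is an ordinary runtime bool here, so Numba type-checks BOTH branches.
    # Appending a float in each branch keeps the list's item type consistent (float64).
    aNumbaList = List()
    for i in range(100):
        if flagA:
            aNumbaList.append(float(i))
        else:
            aNumbaList.append(i / 10)
    return aNumbaList
```

The cost is a test of flagA in every iteration, and float items even when flagA is True.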


Solution

  • In Numba, global variables are treated as compile-time constants, so you can use that to do what you want. Here is an example:

    import numba as nb   # v0.58.1
    
    flagA = True
    
    @nb.njit
    def physics():  # no parameter: flagA is read from the global scope
        aNumbaList = nb.typed.List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList
    

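    This compiles and runs without error, whereas passing flagA as a parameter results in an error: Numba then type-checks both branches, and the items appended in the if and else branches have different types (int and float), so the typed list's item type (inferred as int64 in the question's traceback) cannot accept the float64 items.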

    That being said, global variables are not great in terms of software engineering, and you may want to compile the function for different configurations at runtime (e.g. based on an initialisation process, while avoiding writing to global variables).

    An alternative solution is a factory: an outer function defines and returns the compiled function, and the compiled function reads a variable of the outer function (a closure variable). Numba treats such a variable like a global, i.e. as a compile-time constant, and its value can be passed as a parameter to the outer function. Here is an example:

    import numba as nb
    
    def make_physics(flagA):
        @nb.njit
        def fun():
            aNumbaList = nb.typed.List()
            for i in range(100):
                if flagA:
                    aNumbaList.append(i)
                else:
                    aNumbaList.append(i/10)
            return aNumbaList
    
        return fun
    
    physics = make_physics(True)  # Builds a new specialized function every time make_physics is called
    physics()                     # The first call compiles it (flagA frozen to True), then runs it
    

    This does not result in any error either, and it works as intended. Here is the generated assembly code of the physics function, showing that there is no runtime check of flagA within the main loop:

        [...]
    
        movq    %rax, %r12                 ; r12 = an allocated Python object (the list?)
        movq    24(%rax), %rax
        movq    %r14, (%rax)
        xorl    %ebx, %ebx                 ; i = 0
        movabsq $NRT_incref, %r13
        movabsq $numba_list_append, %rbp
    leaq    48(%rsp), %r15             ; (r15 is a pointer to i)
    
    .LBB0_6:                               ; Main loop
        movq    %r12, %rcx                 
        callq   *%r13                      ; Call NRT_incref(r12)
        movq    %rbx, 48(%rsp)             
        movq    %r14, %rcx                 
        movq    %r15, %rdx                 
        callq   *%rbp                      ; Call numba_list_append(r14, pointer_of(i))
        testl   %eax, %eax                 
        jne .LBB0_7                        ; Stop the loop if numba_list_append returned a non-zero value
        incq    %rbx                       ; i += 1
        movq    %r12, %rcx                 
        movabsq $NRT_decref, %rax          
        callq   *%rax                      ; Call NRT_decref(r12)
        cmpq    $100, %rbx                 
        jne .LBB0_6                        ; Loop as long as i < 100
    
        [...]
    

    Regarding the actual use case, memoization and Numba function caching can help avoid compiling the target function many times for the same configuration.
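A sketch of that idea applied to the `Config` class from the question, assuming the flags are plain booleans: the factory is memoized with `functools.cache` on the unpacked flag values, so each distinct configuration builds one specialized function, which Numba compiles lazily on its first call. `make_physics` and `_make_physics` are illustrative names, not Numba APIs, and the import fallback only lets the sketch run without Numba installed.

```python
import functools

try:
    from numba import njit
    from numba.typed import List
except ImportError:  # assumption: plain-Python fallback so the sketch runs without Numba
    def njit(func):
        return func
    List = list


class Config:
    """Same read-only configuration object as in the question."""

    def __init__(self, flagA, flagB):
        self._flagA = flagA
        self._flagB = flagB

    @property
    def flagA(self):
        return self._flagA

    @property
    def flagB(self):
        return self._flagB


@functools.cache
def _make_physics(flagA, flagB):
    # flagA is a closure variable of the jitted function below, so Numba freezes it
    # as a compile-time constant and only the taken branch has to type-check.
    # flagB mirrors the question's Config; this body does not use it.
    @njit
    def physics():
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i / 10)
        return aNumbaList

    return physics


def make_physics(config):
    # Config objects hash by identity, so key the cache on the flag values instead:
    # two Config objects with equal flags then share one compiled function.
    return _make_physics(bool(config.flagA), bool(config.flagB))
```

Because the cache key is the tuple of flag values, `make_physics(Config(True, False))` returns the same function object on every call, so each configuration is compiled at most once per process.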