Tags: python, python-3.x, numba, python-packaging

Splitting numba functions into separate modules in a project for packaging


I have a couple of modules in my project, each containing a few numba functions. I know that the functions get compiled on first import. What I only noticed now is that even if I import just a single function from a module, all of its functions seem to get compiled, since the import takes the same amount of time.

I'd like a more fine-grained approach, since some applications really only need a single function, and compiling all of them is a waste of time.
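
The observation above follows from Python's import semantics rather than from numba itself: from module import name executes the entire module body before binding name, so every decorator in that module runs. If the functions are declared with explicit signatures, numba.jit compiles them at that moment. The minimal sketch below reproduces the effect without numba; fake_jit and the module name manyfuncs are hypothetical stand-ins, and fake_jit merely records which functions it was applied to.

```python
# Sketch: importing one name from a module still executes the whole module.
# fake_jit is a hypothetical stand-in for an eagerly compiling numba.jit;
# it only records which functions it was applied to, so numba is not needed.
import importlib
import sys
import tempfile
from pathlib import Path

MODULE_SOURCE = '''
compiled = []  # names of functions "compiled" while the module body ran

def fake_jit(f):
    compiled.append(f.__name__)  # an eager numba.jit would compile here
    return f

@fake_jit
def fun1(a, b):
    return a + b

@fake_jit
def fun2(a, b):
    return a * b
'''

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "manyfuncs.py").write_text(MODULE_SOURCE)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        from manyfuncs import fun1  # only fun1 is requested...
    finally:
        sys.path.remove(tmp)

compiled = sys.modules["manyfuncs"].compiled
print(fun1(1, 2))  # 3
print(compiled)    # ['fun1', 'fun2']: fun2 was decorated as well
```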

To that end, I've split the functions into separate modules like this:

Project/
|--src/
|  |-- __init__.py
|  |-- fun1.py
|  |-- fun2.py
|  |-- fun3.py 
|  |-- fun4.py
|  |-- ...

with __init__.py containing

from .fun1 import fun1
from .fun2 import fun2
...

so that they can be imported as from src import fun1.

This seems to work all right, but there is some repetition in the imports: for instance, every module needs from numba import jit, several need from numpy import zeros, and so on.

So my question is whether this is a reasonable approach, or whether there is a better way to package many numba functions.

Edit:

Putting all the import statements into __init__.py apparently means that all functions still get compiled as soon as any one of them is imported, so there's no gain at all.
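
The cause is that importing src.fun1 first executes src/__init__.py, and that file imports every submodule. The sketch below builds a throwaway package in a temporary directory (demo_src is a hypothetical name standing in for src, with plain functions instead of numba ones) and checks sys.modules after importing only fun1:

```python
# Sketch: "from pkg.fun1 import fun1" runs pkg/__init__.py first, so any
# submodule imported there is loaded (and, with numba, compiled) as well.
# demo_src is a hypothetical throwaway package created in a temp directory.
import importlib
import sys
import tempfile
from pathlib import Path

FILES = {
    "__init__.py": "from .fun1 import fun1\nfrom .fun2 import fun2\n",
    "fun1.py": "def fun1(a, b):\n    return a + b\n",
    "fun2.py": "def fun2(a, b):\n    return a * b\n",
}

with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp, "demo_src")
    pkg.mkdir()
    for name, source in FILES.items():
        (pkg / name).write_text(source)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        from demo_src.fun1 import fun1  # only fun1 is requested...
    finally:
        sys.path.remove(tmp)

fun2_loaded = "demo_src.fun2" in sys.modules
print(fun1(1, 2))   # 3
print(fun2_loaded)  # True: __init__.py imported fun2 as well
```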

Without those imports in __init__.py, I can still import the functions like this:

from src.fun1 import fun1

which seems to work, but the syntax is a bit clunky.
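
If the longer spelling is the main drawback, one option is a module-level __getattr__ in __init__.py (PEP 562, Python 3.7+), which imports a submodule only when its name is first requested, so that from src import fun1 loads src/fun1.py alone. This is a minimal sketch under stated assumptions: each function lives in a submodule of the same name, the set of lazy names is maintained by hand, and lazy_src is a hypothetical throwaway package used in place of src.

```python
# Sketch: lazy re-exports via a module-level __getattr__ (PEP 562, Python 3.7+).
# Assumes one function per submodule, named like the submodule (fun1.py -> fun1).
# lazy_src is a hypothetical throwaway package created in a temp directory.
import importlib
import sys
import tempfile
from pathlib import Path

INIT_SOURCE = '''
import importlib

_LAZY_FUNCTIONS = {"fun1", "fun2"}  # maintained by hand (assumption)

def __getattr__(name):
    # Called only when "name" is not already an attribute of the package.
    if name in _LAZY_FUNCTIONS:
        module = importlib.import_module("." + name, __name__)
        func = getattr(module, name)
        globals()[name] = func  # cache, so later lookups skip __getattr__
        return func
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
'''

FILES = {
    "__init__.py": INIT_SOURCE,
    "fun1.py": "def fun1(a, b):\n    return a + b\n",
    "fun2.py": "def fun2(a, b):\n    return a * b\n",
}

with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp, "lazy_src")
    pkg.mkdir()
    for name, source in FILES.items():
        (pkg / name).write_text(source)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        from lazy_src import fun1  # short spelling, loads only lazy_src.fun1
    finally:
        sys.path.remove(tmp)

print(fun1(1, 2))                      # 3
print("lazy_src.fun1" in sys.modules)  # True
print("lazy_src.fun2" in sys.modules)  # False: fun2.py was never executed
```

With numba, only the requested submodule's body runs, so only its decorators (and any eager compilation they trigger) are executed.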


Solution

  • Interesting question: you're essentially asking how to delay the definition of a function until it's explicitly imported. I think the best way to do this is the one you've described: one function per file, imported with from src.fun1 import fun1.

    I think achieving this when several functions share a file would be very tricky, so I've relaxed the question to: "how can we delay the definition of a function until it's explicitly called (rather than imported)?"

    Trivial Solution

    A trivial way to do this is to just wrap your function inside a dummy outer function.

    This doesn't quite do what we want, because every call to fun1 recreates the inner function and re-applies the numba.jit decorator, so the function is recompiled on every call.

    # main.py
    
    # This lets us see when numba is compiling.
    # See https://numba.pydata.org/numba-doc/dev/reference/envvars.html
    import os
    os.environ["NUMBA_DEBUG_FRONTEND"] = "1"
    
    import fun1
    print("note no numba debug output yet for fun1")
    print("fun1 result is", fun1.fun1(1, 2))
    print("fun1 result is", fun1.fun1(2, 1))
    print("note the function was compiled twice :(")
    
    
    # fun1.py
    
    import numba
    
    # Naively wrap fun1 in another function so it's only declared
    # when the outer function is called.
    def fun1(*args, **kwargs):
        @numba.jit("float32(float32, float32)", cache=False)  # No cache, for debugging
        def __fun1(a, b):
            return a + b
        return __fun1(*args, **kwargs)
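
The recompilation can also be reproduced without numba. In this sketch counting_jit is a hypothetical stand-in for numba.jit that only counts how often it is applied. The naive wrapper applies it on every call; the second variant (an alternative, not part of the original answer) builds the decorated inner function once through functools.lru_cache and reuses it afterwards.

```python
# Sketch: the naive wrapper re-decorates on every call; caching the
# factory with functools.lru_cache decorates only once.
# counting_jit is a hypothetical stand-in for numba.jit (no numba needed).
from functools import lru_cache

decorations = {"naive": 0, "cached": 0}

def counting_jit(key):
    def decorator(f):
        decorations[key] += 1  # numba.jit with a signature would compile here
        return f
    return decorator

def fun1_naive(a, b):
    # A new inner function (and a new decoration) on every call.
    @counting_jit("naive")
    def _fun1(a, b):
        return a + b
    return _fun1(a, b)

@lru_cache(maxsize=None)
def _build_fun1():
    # Body runs once; later calls return the cached decorated function.
    @counting_jit("cached")
    def _fun1(a, b):
        return a + b
    return _fun1

def fun1_cached(a, b):
    return _build_fun1()(a, b)

results = [fun1_naive(1, 2), fun1_naive(2, 1), fun1_cached(1, 2), fun1_cached(2, 1)]
print(results)      # [3, 3, 3, 3]
print(decorations)  # {'naive': 2, 'cached': 1}
```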
    
    

    More advanced solution using a decorator

    The trivial solution wraps your function in another function, which smells a lot like a decorator.

    I've created a decorator (the outer decorator) which takes as input another decorator (the inner decorator). The outer decorator applies the inner decorator (numba.jit in this case) only the first time that the function is called. It then re-uses the inner-decorated function on subsequent calls.

    # main.py
    
    # This lets us see when numba is compiling.
    # See https://numba.pydata.org/numba-doc/dev/reference/envvars.html
    import os
    os.environ["NUMBA_DEBUG_FRONTEND"] = "1"
    
    import fun2
    print("note no numba debug output yet for fun2")
    print("fun2 result is", fun2.fun2(3, 4))
    print("fun2 result is", fun2.fun2(5, 6))
    print("note the function was compiled only once :)")
    
    
    # fun2.py
    
    import numba
    from functools import wraps
    
    def delayed(internal_decorator):
        def _delayed(f):
            inner_decorated = None
            @wraps(f)
            def wrapper(*args, **kwds):
                nonlocal inner_decorated
                if inner_decorated is None:
                    inner_decorated = internal_decorator(f)
                return inner_decorated(*args, **kwds)
            return wrapper
        return _delayed
    
    @delayed(numba.jit("float32(float32, float32)", cache=False))
    def fun2(a, b):
        return a * b
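
The claim that the inner decorator runs only once, and only on the first call, can be checked by exercising the delayed decorator with a counting stand-in instead of numba.jit. The sketch below is a variant of fun2.py under two assumptions: counting_jit is a hypothetical decorator that only records when it is applied, and a threading.Lock is added (not in the original) so that two threads making the first call at the same time cannot both apply the inner decorator.

```python
# Sketch: the delayed decorator from fun2.py, plus a lock for thread safety,
# tested with a hypothetical counting_jit stand-in for numba.jit.
import threading
from functools import wraps

applied = []  # records each time the inner decorator runs

def counting_jit(f):
    applied.append(f.__name__)  # numba.jit with a signature would compile here
    return f

def delayed(internal_decorator):
    def _delayed(f):
        inner_decorated = None
        lock = threading.Lock()

        @wraps(f)
        def wrapper(*args, **kwds):
            nonlocal inner_decorated
            if inner_decorated is None:
                with lock:
                    # Re-check inside the lock: another thread may have won.
                    if inner_decorated is None:
                        inner_decorated = internal_decorator(f)
            return inner_decorated(*args, **kwds)
        return wrapper
    return _delayed

@delayed(counting_jit)
def fun2(a, b):
    return a * b

applied_at_definition = list(applied)  # nothing applied yet
results = [fun2(3, 4), fun2(5, 6)]
print(applied_at_definition, results, applied)  # [] [12, 30] ['fun2']
```

Passing numba.jit("float32(float32, float32)") as in fun2.py keeps the explicit signature but moves compilation to the first call. Note that numba.jit used without a signature already compiles lazily, on the first call for each new combination of argument types, so the delayed wrapper mainly pays off when explicit signatures are wanted.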