I have a couple of modules in my project that each contain a few numba
functions. I know that on first import the functions get compiled.
What I only noticed now is that even if I import a single function from a module, all of its functions seem to get compiled, since the import takes the same amount of time.
I'd like to achieve a more fine-grained approach to this, since for some applications you really only need a single function, hence compiling all is a waste of time.
For that I've split up the functions into separate modules like this:
Project/
|--src/
| |-- __init__.py
| |-- fun1.py
| |-- fun2.py
| |-- fun3.py
| |-- fun4.py
| |-- ...
with __init__.py containing
from .fun1 import fun1
from .fun2 import fun2
...
so they can be imported like from src import fun1.
This seems to work alright, but there's a bit of repetition at the import level: for instance, every function needs from numba import jit, a few of them need from numpy import zeros, and so on.
So my question is whether that's an OK way to do it, or if there's a better approach to packaging many numba functions.
Putting all the import statements into __init__.py apparently means that all functions get compiled as soon as any one of them is imported - so there's no gain at all. I can still import the functions like
from src.fun1 import fun1
which works, but the syntax is a bit clunky.
Interesting question - you're essentially asking how to delay the definition of a function until it's explicitly imported.
I think the best way to do this is, as you've said, using from src.fun1 import fun1 and having one function per file.
I think achieving this when you have multiple functions in the same file might be very tricky, so I've relaxed the question to "how can we delay the definition of a function until it's explicitly called (not imported)".
A trivial way to do this is to just wrap your function inside a dummy outer function.
This doesn't quite do what we want, because every call to fun1 will result in the inner function and the numba.jit decorator being recreated, needing recompilation each time.
# main.py
# This lets us see when numba is compiling.
# See https://numba.pydata.org/numba-doc/dev/reference/envvars.html
import os
os.environ["NUMBA_DEBUG_FRONTEND"] = "1"
import fun1
print("note no numba debug output yet for fun1")
print("fun1 result is", fun1.fun1(1, 2))
print("fun1 result is", fun1.fun1(2, 1))
print("note the function was compiled twice :(")
# fun1.py
import numba

# Naively wrap fun1 in another function so it's only declared
# when the outer function is called.
def fun1(*args, **kwargs):
    @numba.jit("float32(float32, float32)", cache=False)  # No cache, for debugging
    def __fun1(a, b):
        return a + b
    return __fun1(*args, **kwargs)
The trivial solution is to wrap your function in another function.... smells a lot like a decorator....
I've created a decorator (the outer decorator) which takes as input another decorator (the inner decorator).
The outer decorator applies the inner decorator (numba.jit in this case) only the first time that the function is called, and then re-uses the inner-decorated function on subsequent calls.
# main.py
# This lets us see when numba is compiling.
# See https://numba.pydata.org/numba-doc/dev/reference/envvars.html
import os
os.environ["NUMBA_DEBUG_FRONTEND"] = "1"
import fun2
print("note no numba debug output yet for fun2")
print("fun2 result is", fun2.fun2(3, 4))
print("fun2 result is", fun2.fun2(5, 6))
print("note the function was compiled only once :)")
# fun2.py
import numba
from functools import wraps

def delayed(internal_decorator):
    def _delayed(f):
        inner_decorated = None
        @wraps(f)
        def wrapper(*args, **kwds):
            nonlocal inner_decorated
            if inner_decorated is None:
                inner_decorated = internal_decorator(f)
            return inner_decorated(*args, **kwds)
        return wrapper
    return _delayed

@delayed(numba.jit("float32(float32, float32)", cache=False))
def fun2(a, b):
    return a * b
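Note that delayed doesn't depend on numba at all, so you can verify the apply-once behaviour with a plain Python stand-in for numba.jit (counting_decorator is made up for illustration):

```python
from functools import wraps

def delayed(internal_decorator):
    # Same helper as in fun2.py above.
    def _delayed(f):
        inner_decorated = None
        @wraps(f)
        def wrapper(*args, **kwds):
            nonlocal inner_decorated
            if inner_decorated is None:
                inner_decorated = internal_decorator(f)
            return inner_decorated(*args, **kwds)
        return wrapper
    return _delayed

apply_count = 0

def counting_decorator(f):
    # Stand-in for numba.jit: count how many times it is applied.
    global apply_count
    apply_count += 1
    return f

@delayed(counting_decorator)
def add(a, b):
    return a + b

print(apply_count)  # 0: nothing applied at definition time
print(add(1, 2))    # 3
print(add(3, 4))    # 7
print(apply_count)  # 1: inner decorator applied exactly once
```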