
Numba: How to turn on/off just-in-time compilation programmatically (setting NUMBA_DISABLE_JIT environment variable)?


I have written a script which is intended to test the efficacy of the @numba.jit decorators that I've added to several functions. I want to first exercise the annotated functions without just-in-time compilation, and then run them again with the @numba.jit decorators in effect, so that I can compare the two results.

I've tried doing this by modifying the value of the NUMBA_DISABLE_JIT environment variable via os.environ, but I'm not sure yet that this is having the desired effect. For example:

import os
import time

# expensive_function is defined elsewhere and decorated with @numba.jit

# run first without and then with numba's just-in-time compilation
for flag in [1, 0]:

    # enable/disable numba's just-in-time compilation
    os.environ["NUMBA_DISABLE_JIT"] = str(flag)

    # time an arbitrary number of iterations of the JIT decorated function
    start = time.time()
    for _ in range(1000):
        expensive_function()
    end = time.time()

    # display the elapsed time
    if flag == 0:
        preposition = "with"
    else:
        preposition = "without"
    print("Elapsed time " + preposition + " numba: {t}".format(t=(end - start)))

Is setting the environment variable NUMBA_DISABLE_JIT as above actually disabling/enabling the JIT compilation of all functions decorated with @numba.jit, as I assume? If not, is there a better way to skin this cat?
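One way to check whether the variable took effect is to inspect the decorated function itself. The sketch below assumes numba is installed and uses a hypothetical toy function `f`. It starts a fresh interpreter for each flag value, so `NUMBA_DISABLE_JIT` is already in the environment when numba is imported, and then reports whether the decorator produced a numba dispatcher (which has a `.py_func` attribute) or handed back the plain Python function.

```python
import os
import subprocess
import sys

# Child script: numba is imported only after the parent has set the
# environment, so NUMBA_DISABLE_JIT is in place when the decorator runs.
# f is a hypothetical toy function used only for this check.
CHILD = """
import numba

@numba.njit
def f(x):
    return x + 1

# A compiled function is a numba dispatcher exposing .py_func;
# with JIT disabled the decorator returns the plain Python function.
print(hasattr(f, "py_func"))
"""


def jit_active(disable_flag):
    """Run CHILD in a new interpreter with NUMBA_DISABLE_JIT set to disable_flag."""
    env = dict(os.environ, NUMBA_DISABLE_JIT=str(disable_flag))
    result = subprocess.run(
        [sys.executable, "-c", CHILD],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip() == "True"


for flag in (1, 0):
    print(f"NUMBA_DISABLE_JIT={flag}: JIT active = {jit_active(flag)}")
```

Because each run uses its own process, this also shows the general pattern for toggling the JIT with the environment variable: set it before numba is imported rather than changing `os.environ` inside a running script.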


Solution

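  • I don't think that flag does what you'd like. Numba consults NUMBA_DISABLE_JIT when the @numba.jit decorator is applied, so once expensive_function has been decorated, changing os.environ inside the loop doesn't switch compilation on or off. To disable the JIT with the environment variable, it needs to be set before numba is imported (e.g. in the shell that launches the script).

    With numba you can always access the original Python function through the .py_func attribute of the jitted function, so that is a simpler way to compare the two: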

    import numba
    import numpy as np
    
    @numba.njit
    def expensive_function(arr):
        ans = 0.0
        for a in arr:
            ans += a
        return ans
    
    arr = np.random.randn(1_000_000)
    
    # call once first so the compilation time isn't included in the timing
    expensive_function(arr)
    
    %timeit expensive_function(arr)
    # 1.51 ms ± 24.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    %timeit expensive_function.py_func(arr)
    # 134 ms ± 11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)