
Pass a function and all its subfunctions into njit


So I recently discovered Numba, and I'm thoroughly amazed by it. When trying it out, I used a bubble sort as the test function, but since my bubblesort function calls another (plain Python) function, I get errors when calling njit on it.

I've worked around this by first calling njit on the helper function and then having bubblesort call the jitted helper. That works, but it forces me to define two bubblesort functions when I want to compare the jitted and pure-Python versions. Is there another way of doing this?

This is what I'm doing:

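from numba import njit


def bytaintill(l):
    # One bubble-sort pass: swap out-of-order neighbours, report whether anything moved.
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed


# Jitted copy of the helper.
bytaintill_njit = njit()(bytaintill)

# Version to be jitted: it calls the jitted helper.
def bubblesort(l):
    not_done = True
    while not_done:
        not_done = bytaintill_njit(l)
    return

# Pure-Python version: it calls the plain helper.
def bubble(l):
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return

bubblesort_njit = njit()(bubblesort)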

Solution

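  • To expand on my comment: you don't need to define new functions; you can bind the jitted version to the same name instead. A function compiled in nopython mode can only call other jitted functions (or functions Numba supports natively, such as many math and NumPy functions), so the helper has to be jitted too, and bubble then picks up the jitted bytaintill by name. Usually, the most convenient way to do this is the @jit decorator (or @njit, which is short for @jit(nopython=True)).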

    from numba import njit
    
    @njit
    def bytaintill(l):
        changed = False
        for i in range(len(l) - 1):
            if l[i] > l[i+1]:
                l[i], l[i+1] = l[i+1], l[i]
                changed = True
        return changed
    
    @njit
    def bubble(l):
        not_done = True
        while not_done:
            not_done = bytaintill(l)
        return
    

    For benchmarking purposes, you can simply comment out the decorators. If you'd rather switch back and forth between the jitted and pure-Python versions, you could try something like this instead:

    from numba import njit
    
    do_jit = True  # set to True or False
    
    def bytaintill(l):
        changed = False
        for i in range(len(l) - 1):
            if l[i] > l[i+1]:
                l[i], l[i+1] = l[i+1], l[i]
                changed = True
        return changed
    
    def bubble(l):
        not_done = True
        while not_done:
            not_done = bytaintill(l)
        return
    
    if do_jit:
        bytaintill = njit()(bytaintill)
        bubble = njit()(bubble)
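If the goal is to time both versions side by side without maintaining two copies of the source, one option is a small factory that builds the helper and the sort either with or without njit. make_bubble and use_jit are hypothetical names for this sketch. It assumes NumPy is available and passes NumPy arrays rather than Python lists, since recent Numba versions warn about the "reflected list" handling used for list arguments. It also relies on Numba treating the captured inner bytaintill as a compile-time constant when compiling bubble. The try/except fallback is only there so the sketch still runs (without any speed-up) where Numba is not installed; it is not part of Numba.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so this sketch runs without Numba (both versions are then pure Python).
    def njit(func):
        return func


def make_bubble(use_jit):
    """Build a bubble sort from one definition, jitted only if use_jit is True."""
    wrap = njit if use_jit else (lambda func: func)

    @wrap
    def bytaintill(l):
        # One pass: swap out-of-order neighbours, report whether anything moved.
        changed = False
        for i in range(len(l) - 1):
            if l[i] > l[i + 1]:
                l[i], l[i + 1] = l[i + 1], l[i]
                changed = True
        return changed

    @wrap
    def bubble(l):
        # bytaintill is captured from the enclosing scope, so it is the
        # jitted helper exactly when bubble itself is jitted.
        not_done = True
        while not_done:
            not_done = bytaintill(l)

    return bubble


bubble_py = make_bubble(use_jit=False)
bubble_jit = make_bubble(use_jit=True)

data = np.random.default_rng(0).integers(0, 1000, size=300)
bubble_jit(data.copy())  # warm-up call so compilation time is not measured

for name, sort in [("python", bubble_py), ("njit", bubble_jit)]:
    arr = data.copy()  # sort a fresh copy each time; the sort works in place
    start = time.perf_counter()
    sort(arr)
    print(f"{name}: {time.perf_counter() - start:.4f} s")
```

Each call to make_bubble creates fresh function objects, so bubble_jit compiles on its first call; the warm-up call keeps that compile time out of the comparison.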