So I recently discovered Numba and I am thoroughly amazed by it. When trying it out, I used a bubble sort function as a test, but since my bubble sort function calls another function, I get errors when calling njit on it.
I've tackled this by first calling njit on the helper function and then having my bubble sort call the jitted helper. It works, but it forces me to define two bubble sort functions when I want to compare the jitted and pure-Python versions. Is there another way of doing this?
This is what I'm doing:
```python
from numba import njit

def bytaintill(l):
    # One bubble pass: swap adjacent out-of-order elements.
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed

bytaintill_njit = njit()(bytaintill)

def bubblesort(l):
    # Calls the jitted helper, so it can itself be jitted.
    not_done = True
    while not_done:
        not_done = bytaintill_njit(l)
    return

def bubble(l):
    # Pure-Python version for comparison.
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return

bubblesort_njit = njit()(bubblesort)
```
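To expand on my comment: in nopython mode a jitted function can call other jitted functions but not ordinary Python functions, which is why your helper has to be compiled too. You don't need to define new functions for that, though; you can bind the jitted version to the same name. Usually the most convenient way to do so is the `@jit` decorator (or `@njit`, which is short for `@jit(nopython=True)`).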
```python
from numba import njit

@njit
def bytaintill(l):
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed

@njit
def bubble(l):
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return
```
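A minimal usage sketch, assuming NumPy is available: the decorated functions are called like ordinary Python functions and sort the array in place. With Numba, the first call also triggers compilation, so for benchmarking it is worth making one warm-up call before timing. The `try`/`except` fallback to a do-nothing `njit` is an assumption added only so the snippet also runs where Numba is not installed; the array size and repeat count are arbitrary.

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: identity stand-in so the sketch still runs without Numba.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # used as @njit
        return lambda func: func  # used as @njit() or njit()(func)


@njit
def bytaintill(l):
    # One bubble pass; returns True if any pair was swapped.
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i + 1]:
            l[i], l[i + 1] = l[i + 1], l[i]
            changed = True
    return changed


@njit
def bubble(l):
    # Repeat passes until one completes without any swap.
    not_done = True
    while not_done:
        not_done = bytaintill(l)


data = np.random.default_rng(0).random(300)

arr = data.copy()
bubble(arr)  # with Numba, this first call also compiles both functions
print("sorted:", bool(np.all(arr[:-1] <= arr[1:])))

# Time only already-compiled calls; copy so every run sorts unsorted data.
elapsed = timeit.timeit(lambda: bubble(data.copy()), number=3)
print(f"3 runs: {elapsed:.4f} s")
```

Passing a NumPy array rather than a Python list also avoids Numba's handling of reflected lists, which it has flagged for deprecation.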
For benchmarking purposes, you can simply comment out the decorators. If you prefer to be able to switch back and forth between the jitted and Python versions, you could instead try something like this:
```python
from numba import njit

do_jit = True  # set to True or False

def bytaintill(l):
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i+1]:
            l[i], l[i+1] = l[i+1], l[i]
            changed = True
    return changed

def bubble(l):
    not_done = True
    while not_done:
        not_done = bytaintill(l)
    return

if do_jit:
    # Rebind both names to their jitted versions.
    bytaintill = njit()(bytaintill)
    bubble = njit()(bubble)
```
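The same switch can be kept in decorator form. The sketch below assumes a hypothetical helper, `maybe_njit`, that applies `njit` only when a module-level flag is set; the flag must be set before the decorated definitions run. As in the previous sketch, the identity fallback for a missing Numba install is an assumption for portability only.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: identity stand-in so the sketch still runs without Numba.
    def njit(func):
        return func

DO_JIT = True  # flip to False for the pure-Python versions


def maybe_njit(func):
    """Hypothetical helper: compile with njit only when DO_JIT is set."""
    return njit(func) if DO_JIT else func


@maybe_njit
def bytaintill(l):
    # One bubble pass; returns True if any pair was swapped.
    changed = False
    for i in range(len(l) - 1):
        if l[i] > l[i + 1]:
            l[i], l[i + 1] = l[i + 1], l[i]
            changed = True
    return changed


@maybe_njit
def bubble(l):
    # Repeat passes until the array is sorted.
    not_done = True
    while not_done:
        not_done = bytaintill(l)


arr = np.array([5.0, 1.0, 4.0, 2.0, 3.0])
bubble(arr)
print(arr)  # [1. 2. 3. 4. 5.]
```

Numba also documents a `NUMBA_DISABLE_JIT` environment variable; setting it to `1` before starting Python makes its decorators leave functions uncompiled, which gives a pure-Python run without touching the code at all.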