I would like to wrap an external function with Numba, with the requirement that the resulting function can be cached with njit(cache=True),
as is already possible with the pure-Numba implementation below (I'm just using dgesv
as an example):
import numba as nb
import numpy as np
@nb.njit(cache=True)
def dgesv_numba(A, b):
    """Solve ``A @ x = b`` with Numba's built-in ``np.linalg.solve`` support.

    This is the cacheable reference implementation the other wrappers are
    compared against.
    """
    solution = np.linalg.solve(A, b)
    return solution
I've tried with ctypes:
import ctypes as ct
from ctypes.util import find_library
from numba import types
from numba.core import cgutils
from numba.extending import intrinsic
@intrinsic
def ptr_from_val(typingctx, data):
    """Numba intrinsic that returns a pointer to a stack copy of ``data``.

    The value is stored into a freshly allocated stack slot and the address
    of that slot is returned, typed as ``CPointer(<type of data>)``.
    """
    # from https://stackoverflow.com/questions/51541302/how-to-wrap-a-cffi-function-in-numba-taking-pointers

    def codegen(context, builder, signature, args):
        # Spill the single incoming value to the stack and hand back the
        # address of that slot.
        (value,) = args
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(data)(data), codegen
# ctypes pointer types used for the dgesv_ arguments, all of which are
# declared as pointers in the argtypes list below.
ptr_int = ct.POINTER(ct.c_int)
ptr_double = ct.POINTER(ct.c_double)
# Argument types of dgesv_, in call order.
argtypes = [
ptr_int, # n
ptr_int, # nrhs
ptr_double, # a
ptr_int, # lda
ptr_int, # ipiv
ptr_double, # b
ptr_int, # ldb
ptr_int, # info
]
# Load the LAPACK shared library located by find_library and bind dgesv_.
lapack_ctypes = ct.CDLL(find_library("lapack"))
_dgesv_ctypes = lapack_ctypes.dgesv_
_dgesv_ctypes.argtypes = argtypes
# dgesv_ has no return value.
_dgesv_ctypes.restype = None
# Or get it from scipy
# addr = nb.extending.get_cython_function_address(
# "scipy.linalg.cython_lapack", "dgesv"
# )
# functype = ct.CFUNCTYPE(None, *argtypes)
# _dgesv_ctypes = functype(addr)
@nb.njit(cache=True)
def args(A, b):
    """Prepare the buffers and int32 scalars shared by the dgesv_ wrappers.

    Returns the tuple ``(rhs, n, nrhs, ipiv, info)`` where

    * ``rhs`` is ``b[:, None]`` for 1-D ``b`` and ``b.T.copy()`` otherwise,
    * ``n`` is ``A.shape[0]`` as int32,
    * ``nrhs`` is the number of right-hand sides as int32,
    * ``ipiv`` is a zero-filled int32 array of length ``n``,
    * ``info`` is an int32 zero.
    """
    order = np.int32(A.shape[0])
    status = np.int32(0)
    pivots = np.zeros((order,), dtype=np.int32)
    if b.ndim == 1:
        # Single right-hand side: add a trailing axis.
        rhs = b[:, None]  # .reshape(-1, 1) # change to reshape numba < 0.57
        rhs_count = np.int32(1)
    else:
        # Several right-hand sides: work on a transposed copy.
        rhs = b.T.copy()
        rhs_count = np.int32(b.shape[1])
    return rhs, order, rhs_count, pivots, status
@nb.njit(cache=True)
def dgesv_ctypes(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through ctypes.

    Parameters
    ----------
    A : 2-D array
        Coefficient matrix; a transposed copy is handed to LAPACK.
    b : 1-D or 2-D array
        Right-hand side(s), prepared by ``args``.

    Returns
    -------
    The transpose of the buffer that ``dgesv_`` was given for ``b``. For
    2-D ``b`` this has the shape of ``b``; for 1-D ``b`` it has shape
    ``(1, n)``.

    Raises
    ------
    Exception
        If ``dgesv_`` reports a non-zero ``info`` status.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # ``ptr_from_val`` hands LAPACK the address of a stack *copy*, so a scalar
    # ``info`` would never observe the status written by dgesv_. Use a
    # one-element array instead: its storage can be read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    _dgesv_ctypes(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        A.T.copy().ctypes,  # Dunno? is there a better way to do this?
        ptr_from_val(n),
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),
        info.ctypes,
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
and with cffi:
import cffi
# Declare the C prototype of dgesv_ to cffi: every argument is a pointer and
# nothing is returned.
ffi = cffi.FFI()
ffi.cdef(
"""
void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb,
int *info);
"""
)
# Open the LAPACK shared library located by find_library and bind dgesv_.
lapack_cffi = ffi.dlopen(find_library("lapack"))
_dgesv_cffi = lapack_cffi.dgesv_
@nb.njit(cache=True)
def dgesv_cffi(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through cffi.

    Parameters
    ----------
    A : 2-D array
        Coefficient matrix; a transposed copy is handed to LAPACK.
    b : 1-D or 2-D array
        Right-hand side(s), prepared by ``args``.

    Returns
    -------
    The transpose of the buffer that ``dgesv_`` was given for ``b``. For
    2-D ``b`` this has the shape of ``b``; for 1-D ``b`` it has shape
    ``(1, n)``.

    Raises
    ------
    Exception
        If ``dgesv_`` reports a non-zero ``info`` status.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # ``ptr_from_val`` hands LAPACK the address of a stack *copy*, so a scalar
    # ``info`` would never observe the status written by dgesv_. Use a
    # one-element array instead: its storage can be read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    _dgesv_cffi(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        ffi.from_buffer(A.T.copy()),
        ptr_from_val(n),
        ffi.from_buffer(ipiv),
        ffi.from_buffer(b),
        ptr_from_val(n),
        ffi.from_buffer(info),
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
but in both cases I get a warning that the function cannot be cached as I have used ctypes pointers:
/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:79: NumbaWarning: Cannot cache compiled function "dgesv_ctypes" as it uses dynamic globals (such as ctypes pointers and large global arrays)
@nb.njit(cache=True)
/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:97: NumbaWarning: Cannot cache compiled function "dgesv_cffi" as it uses dynamic globals (such as ctypes pointers and large global arrays)
@nb.njit(cache=True)
I have managed to do it with Numba's WrapperAddressProtocol (WAP):
class Dgesv(nb.types.WrapperAddressProtocol):
    """Expose LAPACK ``dgesv_`` to Numba via the wrapper address protocol."""

    def __wrapper_address__(self):
        """Return the address of ``dgesv_`` in the ctypes-loaded library."""
        as_void_ptr = ct.cast(lapack_ctypes.dgesv_, ct.c_voidp)
        return as_void_ptr.value

    def signature(self):
        """Return the Numba signature of ``dgesv_``: eight pointers, no result."""
        int_ptr = nb.types.CPointer(nb.int32)
        double_ptr = nb.types.CPointer(nb.float64)
        return nb.types.void(
            int_ptr,  # n
            int_ptr,  # nrhs
            double_ptr,  # a
            int_ptr,  # lda
            int_ptr,  # ipiv
            double_ptr,  # b
            int_ptr,  # ldb
            int_ptr,  # info
        )
@nb.njit(cache=True)
def dgesv_wap(f, A, b):
    """Solve ``A @ x = b`` by calling a ``dgesv_``-compatible function ``f``.

    Parameters
    ----------
    f : first-class function
        Callable with the ``dgesv_`` signature (eight pointer arguments, no
        return value), e.g. a ``Dgesv()`` instance.
    A : 2-D array
        Coefficient matrix; a transposed copy is handed to ``f``.
    b : 1-D or 2-D array
        Right-hand side(s), prepared by ``args``.

    Returns
    -------
    The transpose of the buffer that ``f`` was given for ``b``. For 2-D
    ``b`` this has the shape of ``b``; for 1-D ``b`` it has shape ``(1, n)``.

    Raises
    ------
    Exception
        If ``f`` reports a non-zero ``info`` status.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # ``ptr_from_val`` hands ``f`` the address of a stack *copy*, so a scalar
    # ``info`` would never observe the status written by the callee. Use a
    # one-element array instead: its storage can be read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    f(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        A.T.copy().ctypes,
        ptr_from_val(n),
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),
        info.ctypes,
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
but the resulting function is significantly slower than the other methods and it isn't really what I want as you have to pass the function as an argument for the caching to work:
# Check every wrapper against the np.linalg.solve baseline, then time them.
rng = np.random.default_rng()
for i in range(3, 5):
N = i
A = rng.random((N, N))
# 1000 right-hand sides generated from a known solution x (b = A @ x).
x = rng.random((N, 1000))
b = A @ x
# The wrappers are given copies of A and b; the baseline gets the originals.
_ctypes = dgesv_ctypes(A.copy(), b.copy())
_cffi = dgesv_cffi(A.copy(), b.copy())
_wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
_numba = dgesv_numba(A, b)
assert np.allclose(_ctypes, _numba)
assert np.allclose(_cffi, _numba)
assert np.allclose(_wap, _numba)
# The baseline must also recover the known solution.
assert np.allclose(x, _numba)
print("all good")
# IPython %timeit magics: these lines only run in IPython/Jupyter.
%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)
Output:
all good
56.5 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
55.8 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
89.6 µs ± 2.57 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
59.7 µs ± 894 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
so how do I do it whilst retaining the performance of the other implementations?
OK, it is actually in the documentation: you have to use types.ExternalFunction
(though its docstring says it is for internal use only) and load the library with llvmlite. Note, however, that this implementation can only be called from inside an njit
decorated function:
from ctypes.util import find_library
from llvmlite import binding
from numba import types, njit
# Load LAPACK into the process through llvmlite so that the "dgesv_" symbol
# can be resolved when Numba links the compiled code.
binding.load_library_permanently(find_library("lapack"))
ptr_int = types.CPointer(types.int32)
ptr_double = types.CPointer(types.float64)
# dgesv_ returns nothing (its status is reported through ``info``), so the
# external function is declared with a void return type rather than float64.
_dgesv = types.ExternalFunction(
    "dgesv_",
    types.void(
        ptr_int,  # n
        ptr_int,  # nrhs
        ptr_double,  # a
        ptr_int,  # lda
        ptr_int,  # ipiv
        ptr_double,  # b
        ptr_int,  # ldb
        ptr_int,  # info
    ),
)
@njit(cache=True)
def dgesv_external_function(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` as an external symbol.

    Parameters
    ----------
    A : 2-D array
        Coefficient matrix; a transposed copy is handed to LAPACK.
    b : 1-D or 2-D array
        Right-hand side(s), prepared by ``args``.

    Returns
    -------
    The transpose of the buffer that ``dgesv_`` was given for ``b``. For
    2-D ``b`` this has the shape of ``b``; for 1-D ``b`` it has shape
    ``(1, n)``.

    Raises
    ------
    Exception
        If ``dgesv_`` reports a non-zero ``info`` status.
    """
    b, n, nrhs, ipiv, _ = args(A, b)
    # ``ptr_from_val`` hands LAPACK the address of a stack *copy*, so a scalar
    # ``info`` would never observe the status written by dgesv_. Use a
    # one-element array instead: its storage can be read back after the call.
    info = np.zeros((1,), dtype=np.int32)
    _dgesv(
        ptr_from_val(n),
        ptr_from_val(nrhs),
        A.T.copy().ctypes,
        ptr_from_val(n),
        ipiv.ctypes,
        b.ctypes,
        ptr_from_val(n),
        info.ctypes,
    )
    if info[0] != 0:
        raise Exception("something went wrong")
    return b.T
Timings:
# 5x5 system with 1000 right-hand sides generated from a known solution x.
A = rng.random((5, 5))
x = rng.random((5, 1000))
b = A @ x
# The wrappers are given copies of A and b; the baseline gets the originals.
_ctypes = dgesv_ctypes(A.copy(), b.copy())
_cffi = dgesv_cffi(A.copy(), b.copy())
_wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
_numba = dgesv_numba(A, b)
_ext = dgesv_external_function(A.copy(), b.copy())
# Every wrapper must agree with the np.linalg.solve baseline, and the
# baseline must recover x.
assert np.allclose(_ctypes, _numba)
assert np.allclose(_cffi, _numba)
assert np.allclose(_wap, _numba)
assert np.allclose(x, _numba)
assert np.allclose(_ext, _numba)
# IPython %timeit magics: these lines only run in IPython/Jupyter.
%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)
%timeit dgesv_external_function(A.copy(), b.copy())
Output:
43.3 µs ± 5.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
40.5 µs ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
115 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
53.5 µs ± 5.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
17.4 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)