Search code examples
Tags: python, numba

How to wrap an external function in numba such that the resulting function is cacheable?


I would like to wrap an external function with Numba, but I need the resulting function to be cacheable with njit(cache=True), just as the pure-Numba implementation is (I'm using dgesv only as an example):

import numba as nb
import numpy as np

@nb.njit(cache=True)
def dgesv_numba(A, b):
    """Baseline solver: ``np.linalg.solve(A, b)`` compiled with ``njit``.

    This is the reference implementation the wrappers below are compared
    against; it caches without problems under ``cache=True``.
    """
    solution = np.linalg.solve(A, b)
    return solution

I've tried with ctypes:

import ctypes as ct
from ctypes.util import find_library

from numba import types
from numba.core import cgutils
from numba.extending import intrinsic

@intrinsic
def ptr_from_val(typingctx, data):
    """Return a pointer to a stack copy of the scalar ``data``.

    The value is stored in a freshly allocated slot in the calling function
    (``cgutils.alloca_once_value``) and a ``CPointer`` to that slot is
    returned. Anything written through the pointer lands in that slot only;
    it is NOT propagated back to the variable that was passed in.
    """
    # from https://stackoverflow.com/questions/51541302/how-to-wrap-a-cffi-function-in-numba-taking-pointers
    pointer_sig = types.CPointer(data)(data)

    def codegen(context, builder, signature, args):
        (value,) = args
        return cgutils.alloca_once_value(builder, value)

    return pointer_sig, codegen

# ctypes aliases for Fortran's by-reference arguments (``int *`` / ``double *``).
# NOTE(review): ct.c_int is assumed to match the LAPACK build's integer width
# (32-bit, as in common LP64 builds; ILP64 builds use 64-bit) -- confirm.
ptr_int = ct.POINTER(ct.c_int)
ptr_double = ct.POINTER(ct.c_double)

argtypes = [
    ptr_int,  # n
    ptr_int,  # nrhs
    ptr_double,  # a
    ptr_int,  # lda
    ptr_int,  # ipiv
    ptr_double,  # b
    ptr_int,  # ldb
    ptr_int,  # info
]

# find_library returns None when nothing matches, and ct.CDLL(None) does not
# load LAPACK (on POSIX it opens the running process), which then fails
# opaquely at the ``dgesv_`` lookup. Fail early with a clear message instead.
_lapack_path = find_library("lapack")
if _lapack_path is None:
    raise OSError("could not find a LAPACK shared library via find_library('lapack')")
lapack_ctypes = ct.CDLL(_lapack_path)
_dgesv_ctypes = lapack_ctypes.dgesv_
_dgesv_ctypes.argtypes = argtypes
_dgesv_ctypes.restype = None  # dgesv_ returns nothing; status comes back via info

# Or get it from scipy
# addr = nb.extending.get_cython_function_address(
#     "scipy.linalg.cython_lapack", "dgesv"
# )
# functype = ct.CFUNCTYPE(None, *argtypes)
# _dgesv_ctypes = functype(addr)

@nb.njit(cache=True)
def args(A, b):
    """Prepare the working arguments for LAPACK ``dgesv_`` to solve ``A @ x = b``.

    Parameters
    ----------
    A : 2-D array
        Square coefficient matrix; only its shape is read here.
    b : 1-D or 2-D array
        Right-hand side(s); ``b.shape[0]`` must equal ``A.shape[0]``.

    Returns
    -------
    (_b, n, nrhs, ipiv, info)
        ``_b`` is a private, C-contiguous copy of ``b`` whose raw buffer is
        ``b`` in column-major order, as LAPACK expects. Its shape is
        ``(n, 1)`` for 1-D ``b`` and ``(nrhs, n)`` (i.e. ``b.T``) for 2-D
        ``b``. LAPACK overwrites it with the solution. ``n`` and ``nrhs``
        are int32 sizes. ``ipiv`` is a zeroed int32 pivot buffer of length
        ``n``. ``info`` is only an initial int32 value of 0, not an output
        buffer.

    Raises
    ------
    ValueError
        If ``A`` is not square or ``b`` does not have ``A.shape[0]`` rows.
        Without these checks LAPACK would read or write outside the buffers.
    """
    n_rows = A.shape[0]
    if A.shape[1] != n_rows:
        raise ValueError("A must be a square matrix")
    if b.shape[0] != n_rows:
        raise ValueError("b must have the same number of rows as A")

    # ``b`` is never reassigned here, so ``b.ndim`` stays a compile-time
    # constant and Numba only types the branch that applies.
    if b.ndim == 1:
        # Copy so LAPACK never overwrites the caller's array (the 2-D branch
        # already copies). This also makes a strided 1-D view contiguous.
        _b = b[:, None].copy()  # .reshape(-1, 1) # change to reshape numba < 0.57
        nrhs = np.int32(1)
    else:
        # A C-order copy of b.T has the same memory layout as a Fortran-order b.
        _b = b.T.copy()
        nrhs = np.int32(b.shape[1])

    n = np.int32(n_rows)
    info = np.int32(0)
    ipiv = np.zeros((n,), dtype=np.int32)
    return _b, n, nrhs, ipiv, info


@nb.njit(cache=True)
def dgesv_ctypes(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through ctypes.

    ``A`` is left untouched because LAPACK factorizes a private copy. The
    result is the LAPACK output buffer transposed back. Its shape is
    ``(n, nrhs)`` for 2-D ``b``. For 1-D ``b`` it is ``(1, n)``, unlike
    ``np.linalg.solve``, which returns ``(n,)``.

    ``_dgesv_ctypes`` is a ctypes function global, so Numba still refuses to
    cache this function.

    Raises
    ------
    Exception
        If LAPACK reports an illegal argument (``info < 0``) or a singular
        matrix (``info > 0``).
    """
    sol, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites ``a`` with its LU factors. A C-order copy of A.T is a
    # Fortran-order copy of A, so the caller's A is preserved.
    a_lu = A.T.copy()
    # ``info`` must be a real output buffer. ptr_from_val only copies a value
    # into a temporary stack slot, so a status written there would never be
    # seen after the call, and the old ``if info:`` check could never fire.
    info = np.zeros(1, dtype=np.int32)
    _dgesv_ctypes(
        ptr_from_val(n),  # n
        ptr_from_val(nrhs),  # nrhs
        a_lu.ctypes,  # a (overwritten with LU factors)
        ptr_from_val(n),  # lda
        ipiv.ctypes,  # ipiv
        sol.ctypes,  # b (overwritten with the solution)
        ptr_from_val(n),  # ldb
        info.ctypes,  # info
    )
    if info[0] < 0:
        raise Exception("dgesv: an argument had an illegal value")
    if info[0] > 0:
        raise Exception("dgesv: matrix is singular (zero pivot in LU factorization)")
    return sol.T

and with cffi:

import cffi
ffi = cffi.FFI()
# C prototype of Fortran dgesv_: every argument is passed by reference.
ffi.cdef(
    """
void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb,
    int *info);
"""
)
# find_library returns None when nothing matches, and ffi.dlopen(None) opens
# the standard C library rather than LAPACK, so the ``dgesv_`` lookup below
# would fail opaquely. Fail early with a clear message instead.
_lapack_path = find_library("lapack")
if _lapack_path is None:
    raise OSError("could not find a LAPACK shared library via find_library('lapack')")
lapack_cffi = ffi.dlopen(_lapack_path)
_dgesv_cffi = lapack_cffi.dgesv_


@nb.njit(cache=True)
def dgesv_cffi(A, b):
    """Solve ``A @ x = b`` by calling LAPACK ``dgesv_`` through cffi.

    ``A`` is left untouched because LAPACK factorizes a private copy. The
    result is the LAPACK output buffer transposed back. Its shape is
    ``(n, nrhs)`` for 2-D ``b``. For 1-D ``b`` it is ``(1, n)``, unlike
    ``np.linalg.solve``, which returns ``(n,)``.

    ``_dgesv_cffi`` is a cffi function global, so Numba still refuses to
    cache this function.

    Raises
    ------
    Exception
        If LAPACK reports an illegal argument (``info < 0``) or a singular
        matrix (``info > 0``).
    """
    sol, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites ``a`` with its LU factors. Hand it a Fortran-order
    # copy of A (a C-order copy of A.T).
    a_lu = A.T.copy()
    # ``info`` must be a real output buffer. ptr_from_val only copies a value
    # into a temporary stack slot, so a status written there would never be
    # seen after the call, and the old ``if info:`` check could never fire.
    info = np.zeros(1, dtype=np.int32)
    # NOTE(review): ffi.from_buffer yields raw pointers. Confirm Numba keeps
    # ``a_lu`` and ``ipiv`` alive until dgesv_ returns, since neither is used
    # again after this call.
    _dgesv_cffi(
        ptr_from_val(n),  # n
        ptr_from_val(nrhs),  # nrhs
        ffi.from_buffer(a_lu),  # a (overwritten with LU factors)
        ptr_from_val(n),  # lda
        ffi.from_buffer(ipiv),  # ipiv
        ffi.from_buffer(sol),  # b (overwritten with the solution)
        ptr_from_val(n),  # ldb
        ffi.from_buffer(info),  # info
    )
    if info[0] < 0:
        raise Exception("dgesv: an argument had an illegal value")
    if info[0] > 0:
        raise Exception("dgesv: matrix is singular (zero pivot in LU factorization)")
    return sol.T

but in both cases I get a warning that the function cannot be cached because it uses dynamic globals (ctypes/cffi function pointers):

/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:79: NumbaWarning: Cannot cache compiled function "dgesv_ctypes" as it uses dynamic globals (such as ctypes pointers and large global arrays)
  @nb.njit(cache=True)
/var/folders/v7/vq2l7f812yd450mn3wwmrhtc0000gn/T/ipykernel_4390/2568069903.py:97: NumbaWarning: Cannot cache compiled function "dgesv_cffi" as it uses dynamic globals (such as ctypes pointers and large global arrays)
  @nb.njit(cache=True)

I have managed to do it with Numba's Wrapper Address Protocol (WAP):

class Dgesv(nb.types.WrapperAddressProtocol):
    """First-class-function handle for ``dgesv_`` from ``lapack_ctypes``.

    Numba obtains the raw function address from ``__wrapper_address__`` and
    the call signature from ``signature``.
    """

    def __wrapper_address__(self):
        # Address of the already-loaded ctypes symbol, as a plain integer.
        fn_ptr = ct.cast(lapack_ctypes.dgesv_, ct.c_voidp)
        return fn_ptr.value

    def signature(self):
        # Fortran passes every argument by reference, so all are pointers.
        int_p = nb.types.CPointer(nb.int32)
        dbl_p = nb.types.CPointer(nb.float64)
        return nb.types.void(
            int_p,  # n
            int_p,  # nrhs
            dbl_p,  # a
            int_p,  # lda
            int_p,  # ipiv
            dbl_p,  # b
            int_p,  # ldb
            int_p,  # info
        )


@nb.njit(cache=True)
def dgesv_wap(f, A, b):
    """Solve ``A @ x = b`` by calling ``dgesv_`` through a first-class function.

    Parameters
    ----------
    f : Dgesv
        WrapperAddressProtocol object exposing the ``dgesv_`` address. It is
        passed as an argument rather than read from a global, which is what
        lets ``cache=True`` work here.
    A, b
        As for ``np.linalg.solve``. ``A`` is left untouched.

    The result is the LAPACK output buffer transposed back. Its shape is
    ``(n, nrhs)`` for 2-D ``b``. For 1-D ``b`` it is ``(1, n)``.

    Raises
    ------
    Exception
        If LAPACK reports an illegal argument (``info < 0``) or a singular
        matrix (``info > 0``).
    """
    sol, n, nrhs, ipiv, _ = args(A, b)
    # dgesv_ overwrites ``a`` with its LU factors. Hand it a Fortran-order
    # copy of A (a C-order copy of A.T).
    a_lu = A.T.copy()
    # ``info`` must be a real output buffer. ptr_from_val only copies a value
    # into a temporary stack slot, so a status written there would never be
    # seen after the call, and the old ``if info:`` check could never fire.
    info = np.zeros(1, dtype=np.int32)
    f(
        ptr_from_val(n),  # n
        ptr_from_val(nrhs),  # nrhs
        a_lu.ctypes,  # a (overwritten with LU factors)
        ptr_from_val(n),  # lda
        ipiv.ctypes,  # ipiv
        sol.ctypes,  # b (overwritten with the solution)
        ptr_from_val(n),  # ldb
        info.ctypes,  # info
    )
    if info[0] < 0:
        raise Exception("dgesv: an argument had an illegal value")
    if info[0] > 0:
        raise Exception("dgesv: matrix is singular (zero pivot in LU factorization)")
    return sol.T

However, the resulting function is significantly slower than the other methods. It also isn't really what I want, because the function object has to be passed in as an argument for caching to work:

# Notebook cell: cross-check every wrapper against dgesv_numba, then time them.
# Uses IPython ``%timeit`` magics, so it must run in IPython/Jupyter.

# Unseeded generator: inputs differ from run to run.
rng = np.random.default_rng()

# Correctness check for N = 3 and N = 4, each with 1000 right-hand sides.
for i in range(3, 5):
    N = i
    A = rng.random((N, N))
    x = rng.random((N, 1000))
    b = A @ x
    # The external wrappers get fresh copies of A and b on every call.
    _ctypes = dgesv_ctypes(A.copy(), b.copy())
    _cffi = dgesv_cffi(A.copy(), b.copy())
    _wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
    _numba = dgesv_numba(A, b)
    assert np.allclose(_ctypes, _numba)
    assert np.allclose(_cffi, _numba)
    assert np.allclose(_wap, _numba)
    # The known solution x must also be recovered.
    assert np.allclose(x, _numba)

print("all good")

# Timings reuse A and b from the last loop iteration (N = 4).
# The first three timings include the cost of A.copy()/b.copy(), and the WAP
# timing also includes constructing Dgesv(). dgesv_numba is timed without copies.
%timeit dgesv_ctypes(A.copy(), b.copy())
%timeit dgesv_cffi(A.copy(), b.copy())
%timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
%timeit dgesv_numba(A, b)

Output:

all good
56.5 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
55.8 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
89.6 µs ± 2.57 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
59.7 µs ± 894 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

So how can I do this while retaining the performance of the other implementations?


Solution

  • OK, it is actually in the documentation. You have to use types.ExternalFunction (although its docstring says it is for internal use only) and load the library with llvmlite. The catch is that the resulting external function can only be called from inside an njit-decorated function:

    from ctypes.util import find_library
    from llvmlite import binding
    from numba import types, njit
    
    # find_library returns None when nothing matches. Loading None would not
    # load LAPACK, and the failure would only surface later as an unresolved
    # symbol at compile time. Fail early with a clear message instead.
    _lapack_path = find_library("lapack")
    if _lapack_path is None:
        raise OSError("could not find a LAPACK shared library via find_library('lapack')")
    # Load LAPACK into the process so LLVM can resolve ``dgesv_`` by name when
    # linking Numba-compiled code.
    binding.load_library_permanently(_lapack_path)
    
    # Numba pointer types. They get distinct names so they do not overwrite the
    # ctypes ``ptr_int`` / ``ptr_double`` aliases used by the ctypes wrapper.
    nb_ptr_int = types.CPointer(types.int32)
    nb_ptr_double = types.CPointer(types.float64)
    
    # The external function is referenced by symbol name rather than by a
    # runtime pointer. This is the approach the answer reports as cacheable.
    # NOTE(review): dgesv_ returns void in C, and the float64 result declared
    # here is never used. Consider types.void -- confirm Numba lowers it as intended.
    _dgesv = types.ExternalFunction("dgesv_", types.float64(
        nb_ptr_int,  # n
        nb_ptr_int,  # nrhs
        nb_ptr_double,  # a
        nb_ptr_int,  # lda
        nb_ptr_int,  # ipiv
        nb_ptr_double,  # b
        nb_ptr_int,  # ldb
        nb_ptr_int,  # info
    ))
    
    @njit(cache=True)
    def dgesv_external_function(A, b):
        """Solve ``A @ x = b`` via LAPACK ``dgesv_`` bound as a Numba ExternalFunction.

        ``A`` is left untouched because LAPACK factorizes a private copy. The
        result is the LAPACK output buffer transposed back. Its shape is
        ``(n, nrhs)`` for 2-D ``b``. For 1-D ``b`` it is ``(1, n)``.

        Raises
        ------
        Exception
            If LAPACK reports an illegal argument (``info < 0``) or a
            singular matrix (``info > 0``).
        """
        sol, n, nrhs, ipiv, _ = args(A, b)
        # dgesv_ overwrites ``a`` with its LU factors. Hand it a Fortran-order
        # copy of A (a C-order copy of A.T).
        a_lu = A.T.copy()
        # ``info`` must be a real output buffer. ptr_from_val only copies a
        # value into a temporary stack slot, so a status written there would
        # never be seen after the call, and the old ``if info:`` check could
        # never fire.
        info = np.zeros(1, dtype=np.int32)
        _dgesv(
            ptr_from_val(n),  # n
            ptr_from_val(nrhs),  # nrhs
            a_lu.ctypes,  # a (overwritten with LU factors)
            ptr_from_val(n),  # lda
            ipiv.ctypes,  # ipiv
            sol.ctypes,  # b (overwritten with the solution)
            ptr_from_val(n),  # ldb
            info.ctypes,  # info
        )
        if info[0] < 0:
            raise Exception("dgesv: an argument had an illegal value")
        if info[0] > 0:
            raise Exception("dgesv: matrix is singular (zero pivot in LU factorization)")
        return sol.T
    

    Timings:

    # Notebook cell: cross-check all five solvers on a 5x5 system with 1000
    # right-hand sides, then time them. ``rng`` is the generator created in
    # the earlier benchmark cell. ``%timeit`` is an IPython magic.
    A = rng.random((5, 5))
    x = rng.random((5, 1000))
    b = A @ x
    
    # These first calls also trigger the lazy njit compilation (or the cache
    # load), so the %timeit lines below measure steady-state calls.
    _ctypes = dgesv_ctypes(A.copy(), b.copy())
    _cffi = dgesv_cffi(A.copy(), b.copy())
    _wap = dgesv_wap(Dgesv(), A.copy(), b.copy())
    _numba = dgesv_numba(A, b)
    _ext = dgesv_external_function(A.copy(), b.copy())
    assert np.allclose(_ctypes, _numba)
    assert np.allclose(_cffi, _numba)
    assert np.allclose(_wap, _numba)
    assert np.allclose(x, _numba)
    assert np.allclose(_ext, _numba)
    
    # All timings except dgesv_numba include the cost of A.copy()/b.copy().
    # The WAP timing also includes constructing Dgesv().
    %timeit dgesv_ctypes(A.copy(), b.copy())
    %timeit dgesv_cffi(A.copy(), b.copy())
    %timeit dgesv_wap(Dgesv(), A.copy(), b.copy())
    %timeit dgesv_numba(A, b)
    %timeit dgesv_external_function(A.copy(), b.copy())
    

    Output:

    43.3 µs ± 5.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    40.5 µs ± 2.69 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    115 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    53.5 µs ± 5.41 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
    17.4 µs ± 544 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)