
Cython defining type for functions


I'm trying to build a slice-sampling library in Cython: a generic one, where you supply a log-density and a starting value and get a sample back. I'm working on the univariate case now. Based on the response here, I've come up with the following.

So I have a function defined in cSlice.pyx:

cdef double univariate_slice_sample(f_type_1 logd, double starter, 
                                        double increment_size = 0.5):
    some stuff
    return value

I have defined in cSlice.pxd:

cdef ctypedef double (*f_type_1)(double)
cdef double univariate_slice_sample(f_type_1 logd, double starter, 
                                               double increment_size = *)

where logd is a generic univariate log-density.

In my distribution file, let's say cDistribution.pyx, I have the following:

from cSlice cimport univariate_slice_sample, f_type_1

cdef double log_distribution(alpha_k, y_k, prior):
    some stuff
    return value

cdef double _sample_alpha_k_slice(
        double starter,
        double[:] y_k,
        Prior prior,
        double increment_size
        ):
    cdef f_type_1 f = lambda alpha_k: log_distribution(alpha_k, y_k, prior)
    return univariate_slice_sample(f, starter, increment_size)

cpdef double sample_alpha_k_slice(
        double starter,
        double[:] y_1,
        Prior prior,
        double increment_size = 0.5
        ):
    return _sample_alpha_k_slice(starter, y_1, prior, increment_size)

The wrapper is there because apparently lambdas aren't allowed in cpdef functions.

When I try compiling the distribution file, I get the following:

cDistribution.pyx:289:22: Cannot convert Python object to 'f_type_1'

pointing at the cdef f_type_1 f = ... line.

I'm unsure of what else to do. I want this code to maintain C speed, and importantly not hit the GIL. Any ideas?


Solution

  • You can generate a C callback/wrapper for any Python function at run time (a Python object cannot be cast to a function pointer implicitly), as explained, for example, in this SO post.
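
    One way to build such a run-time wrapper is the standard-library ctypes module; this is a minimal sketch under that assumption, not necessarily what the linked post does. The names F_TYPE_1 and log_density are hypothetical stand-ins for the question's f_type_1 and its log-density. The resulting integer address is what a Cython function could cast to its function-pointer type.

```python
import ctypes

# Prototype for the C signature double (*)(double), mirroring f_type_1.
F_TYPE_1 = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)

def log_density(x):
    # Hypothetical log-density: standard normal, up to an additive constant.
    return -0.5 * x * x

# Wrap the Python function in a C-callable thunk.
# Keep c_callback referenced for as long as the address is in use.
c_callback = F_TYPE_1(log_density)

# Raw address of the thunk as a plain Python int.
address = ctypes.cast(c_callback, ctypes.c_void_p).value

# Round trip: rebuild a callable from the bare address and call it.
from_address = F_TYPE_1(address)
result = from_address(2.0)
```

    Every call through this pointer still runs the Python function and has to acquire the GIL, so it satisfies the type checker without giving C speed.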

    However, at its core the function remains a slow pure-Python function. Numba lets you create real C callbacks via @cfunc. Here is a simplified example:

    from numba import cfunc 
    @cfunc("float64(float64)")
    def id_(x):
        return x
    

    and this is how it could be used:

    %%cython
    ctypedef double(*f_type)(double)
    
    cdef void c_print_double(double x, f_type f):
        print(2.0*f(x))
    
    import numba
    expected_signature = numba.float64(numba.float64)
    def print_double(double x,f):
        # check the signature of f:
        if not f._sig == expected_signature:
            raise TypeError("cfunc has not the right type")
        # it is not possible to cast a Python object to a pointer directly,
        # so we cast the address first to unsigned long long
        c_print_double(x, <f_type><unsigned long long int>(f.address))
    

    And now:

    print_double(1.0, id_)
    # 2.0
    

    We need to check the signature of the cfunc object at run time; otherwise the cast <f_type><unsigned long long int>(f.address) would also "work" for functions with the wrong signature, only to (possibly) crash during the call or produce strange, hard-to-debug errors. I'm not sure my method is the best one, though, even if it works:

    ...
    @cfunc("float32(float32)")
    def id3_(x):
        return x
    
    print_double(1.0, id3_)
    # TypeError: cfunc has not the right type
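
    The same "check the signature before handing out the address" idea can be sketched without Numba, assuming the callback is a ctypes function pointer rather than a cfunc object. The helper checked_address is hypothetical, and it relies on the _restype_ and _argtypes_ class attributes that ctypes.CFUNCTYPE sets on the prototypes it creates.

```python
import ctypes

def checked_address(callback):
    """Return the raw address of a ctypes callback, but only if its
    prototype is double(double); otherwise raise TypeError."""
    proto = type(callback)
    restype = getattr(proto, "_restype_", None)
    argtypes = tuple(getattr(proto, "_argtypes_", ()))
    if restype is not ctypes.c_double or argtypes != (ctypes.c_double,):
        raise TypeError("callback does not have signature double(double)")
    return ctypes.cast(callback, ctypes.c_void_p).value

# A callback with the expected signature passes the check.
good = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(lambda x: x)
good_address = checked_address(good)

# A float(float) callback is rejected instead of being cast blindly.
bad = ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float)(lambda x: x)
try:
    checked_address(bad)
    rejected = False
except TypeError:
    rejected = True
```

    As with the cfunc check, the point is to fail with a clear TypeError in Python rather than crash inside the C call.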