
JAX Python C callbacks


Numba lets you create C callbacks directly in Python with the @cfunc decorator ( https://numba.pydata.org/numba-doc/0.42.0/user/cfunc.html ):

from numba import cfunc

@cfunc("float64(float64)")
def square(x):
    return x**2

To clarify, the resulting function is a pure C function that can be called directly from C code.

Is there an equivalent functionality available in JAX ( https://jax.readthedocs.io/en/latest/# )?

I have been searching for a while but couldn't find anything. I would appreciate any tips.


Solution

  • No, JAX doesn't provide any API similar to Numba's cfunc.