Numba lets you create C callbacks directly in Python with the `@cfunc` decorator ( https://numba.pydata.org/numba-doc/0.42.0/user/cfunc.html ):
```python
from numba import cfunc

@cfunc("float64(float64)")
def square(x):
    return x**2
```
To clarify: the resulting function is a pure C function, which can then be called directly from C code.
Is there an equivalent functionality available in JAX ( https://jax.readthedocs.io/en/latest/# )?
I have been searching for a while but couldn't find anything. I would appreciate any tips.
No, JAX doesn't provide any API similar to Numba's `cfunc`.