I think this is a simple task, but I could not find a solution to it on the web. I have an external C++ library, which I'm using in my Python code, that returns a `ctypes.POINTER(ctypes.c_float)` to me. I want to pass an array of these pointers to a `jax.vmap` function. The problem is that JAX does not accept the `ctypes.POINTER(ctypes.c_float)` type. So, can I somehow cast this pointer to an ordinary `int`? Technically, this is clearly possible, but how do I do it in Python?
Here is an example:
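```python
import ctypes
import jax

lib = ctypes.cdll.LoadLibrary(lib_path)  # lib_path: path to the external library
lib.foo.argtypes = None
lib.foo.restype = ctypes.POINTER(ctypes.c_float)

# Call lib.foo once per batch element; the mapped argument only sets the batch size
bar = jax.vmap(lambda dummy: lib.foo())(jax.numpy.empty(16))
x = jax.numpy.empty((16, 256, 256, 1))  # the shape must be passed as a tuple
y = jax.vmap(lib.bar, in_axes=(0, 0))(x, bar)  # both arguments batched along axis 0
```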
So, I want to invoke `lib.foo` 16 times so that I have an array `bar` containing all the pointers. Then I want to invoke another library function, `lib.bar`, which expects `bar` together with another (batched) parameter, `x`.
The problem is that JAX complains that `ctypes.POINTER(ctypes.c_float)` is not a valid JAX type. This is why I think the solution is to cast the pointers to `int`s and store those `int`s in `bar` instead.
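A minimal sketch of the workflow described in the question, using only `ctypes` and the standard library. `fake_foo` is a hypothetical stand-in for `lib.foo` (it just hands out pointers into ctypes float buffers); the 16 pointers are turned into plain integers with `ctypes.cast(ptr, ctypes.c_void_p).value`, stored in a fixed-width `array.array`, and later cast back to `POINTER(c_float)`. Moving those integers into a JAX array is not shown: by default JAX uses 32-bit integers unless 64-bit mode (`jax_enable_x64`) is enabled, so 64-bit addresses may be truncated or rejected there.

```python
import array
import ctypes

FloatPtr = ctypes.POINTER(ctypes.c_float)

# Hypothetical stand-in for lib.foo: returns a POINTER(c_float) into a fresh buffer.
# The buffers are kept in _buffers so the memory stays valid while addresses are in use.
_buffers = []


def fake_foo(fill):
    buf = (ctypes.c_float * 4)(*([fill] * 4))
    _buffers.append(buf)
    return ctypes.cast(buf, FloatPtr)


# Addresses must fit in the integer type chosen for storage ('Q' = unsigned long long)
assert ctypes.sizeof(ctypes.c_void_p) <= array.array("Q").itemsize

# Call the "library" 16 times and keep only the integer addresses
addresses = (ctypes.cast(fake_foo(float(i)), ctypes.c_void_p).value for i in range(16))
bar = array.array("Q", addresses)

# Later: turn each integer back into a typed pointer before using it
ptrs = [ctypes.cast(addr, FloatPtr) for addr in bar]
print(len(bar), ptrs[3][0], ptrs[15][2])  # 16 3.0 15.0
```

Whoever owns the memory (here `_buffers`, in the real case the C++ library) must keep it alive for as long as the integers are used; an address carries no ownership.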
Listing:
- [SO]: C function called from Python via ctypes returns incorrect value (@CristiFati's answer) - a common pitfall when working with CTypes (calling functions)
- [Python.Docs]: ctypes - A foreign function library for Python
Here's a piece of code exemplifying how to handle pointers and their addresses. The trick is to use `ctypes.addressof` on the pointer's contents (documented in the 2nd item of the listing above).
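One detail worth spelling out: `ctypes.addressof` has to be applied to `ptr.contents` (the pointed-to object), not to `ptr` itself, because `ctypes.addressof(ptr)` returns the address of the pointer object's own storage. A minimal sketch, where `buf` is only an illustrative stand-in for memory returned by a library, that also converts the integer back into a typed pointer with `ctypes.cast`:

```python
import ctypes

buf = (ctypes.c_float * 3)(1.5, 2.5, 3.5)  # illustrative stand-in for library-owned memory
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))

data_addr = ctypes.addressof(ptr.contents)  # address of the floats (what the library sees)
self_addr = ctypes.addressof(ptr)  # address where the pointer value itself is stored

print(data_addr == ctypes.addressof(buf))  # True
print(self_addr == data_addr)  # False

# The pointer's own storage holds data_addr, which can be read back through c_void_p
print(ctypes.c_void_p.from_address(self_addr).value == data_addr)  # True

# Integer -> typed pointer again
ptr_back = ctypes.cast(data_addr, ctypes.POINTER(ctypes.c_float))
print(ptr_back[1])  # 2.5
```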
code00.py:
```python
#!/usr/bin/env python

import ctypes as cts
import sys


CType = cts.c_float
CTypePtr = cts.POINTER(CType)


def ctype_pointer(seq):  # Helper
    CTypeArr = (CType * len(seq))
    ctype_arr = CTypeArr(*seq)
    return cts.cast(ctype_arr, CTypePtr)


def pointer_elements(addr, count):  # Helper
    return tuple(CType.from_address(addr + i * cts.sizeof(CType)).value for i in range(count))


def main(*argv):
    seq = (2.718182, -3.141593, 1.618034, -0.618034, 0)
    ptr = ctype_pointer(seq)
    print(f"Pointer: {ptr}")
    print(f"\nPointer elements: {tuple(ptr[i] for i in range(len(seq)))}")  # Check if pointer has correct data
    ptr_addr = cts.addressof(ptr.contents)  # @TODO - cfati: Straightforward
    print(f"\nAddress: {ptr_addr} (0x{ptr_addr:016X})\nElements from address: {pointer_elements(ptr_addr, len(seq))}")
    ptr_addr0 = cts.cast(ptr, cts.c_void_p).value  # @TODO - cfati: Alternative
    print(f"\nAddresses match: {ptr_addr == ptr_addr0}")


if __name__ == "__main__":
    print(
        "Python {:s} {:03d}bit on {:s}\n".format(
            " ".join(elem.strip() for elem in sys.version.split("\n")),
            64 if sys.maxsize > 0x100000000 else 32,
            sys.platform,
        )
    )
    rc = main(*sys.argv[1:])
    print("\nDone.\n")
    sys.exit(rc)
```
Notes:
- Although it adds a bit of complexity, I introduced the CType "layer" to show that the code should work with any type, not just float (as long as the values in the sequence are of that type)
- The only truly relevant lines are those marked with @TODO
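To illustrate the first note, here is a sketch that passes the element type as a parameter instead of using the module-level `CType`. `make_pointer` and `read_from_address` are illustrative names, not part of the original code, and the choice of `c_int32` and `c_double` is arbitrary:

```python
import ctypes


def make_pointer(ctype, seq):
    # Build a ctypes array and return a typed pointer to its first element
    arr = (ctype * len(seq))(*seq)
    return ctypes.cast(arr, ctypes.POINTER(ctype))


def read_from_address(ctype, addr, count):
    # Read count consecutive elements of type ctype starting at addr
    size = ctypes.sizeof(ctype)
    return tuple(ctype.from_address(addr + i * size).value for i in range(count))


results = []
for ctype, seq in ((ctypes.c_int32, (7, -8, 9)), (ctypes.c_double, (0.5, -1.25))):
    ptr = make_pointer(ctype, seq)
    addr = ctypes.cast(ptr, ctypes.c_void_p).value  # pointer -> int
    results.append(read_from_address(ctype, addr, len(seq)))
print(results)  # [(7, -8, 9), (0.5, -1.25)]
```

The element size comes from `ctypes.sizeof(ctype)`, so the address arithmetic adapts to each type without further changes.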
Output:
```
(py_pc064_03.08_test0_lancer) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q078366208]> python ./code00.py
Python 3.8.19 (default, Apr 6 2024, 17:58:10) [GCC 11.4.0] 064bit on linux

Pointer: <__main__.LP_c_float object at 0x7203e97e7d40>

Pointer elements: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Address: 125361127594576 (0x00007203E97A9A50)
Elements from address: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Addresses match: True

Done.
```