
Marshaling object code for a Numba function


I have a problem that could be solved by Numba: creating Numpy ufuncs for a query server to (a) coalesce simple operations into a single pass over the data, reducing my #1 hotspot (memory bandwidth), and (b) wrap third-party C functions as ufuncs on the fly, giving users of the query system more functionality.

I have an accumulator node that splits up the query and collects results, and compute nodes that actually run Numpy (distinct computers in a network). If the Numba compilation happens on the compute nodes, the effort is duplicated: they work on different data partitions of the same query, and the same query means the same Numba compilation. Moreover, even the simplest Numba compilation takes 96 milliseconds, as long as running a query calculation over millions of points, and that time is better spent computing on the compute nodes.

So I want to do a Numba compilation once on the accumulator node, then send the result to the compute nodes so they can run it. I can guarantee that both have the same hardware, so the object code is compatible.

I've been searching the Numba API for this functionality and haven't found it (apart from a numba.serialize module with no documentation; I'm not sure what its purpose is). The solution might not be a "feature" of the Numba package, but a technique that takes advantage of someone's insider knowledge of Numba and/or LLVM. Does anyone know how to get at the object code, marshal it, and reconstitute it? I can have Numba installed on both machines if that helps; I just can't do anything too expensive on the destination machines.


Solution

  • Okay, it's possible, and the solution makes heavy use of the llvmlite library under Numba.

    Getting the serialized function

    First we define some function with Numba.

    import numba
    
    @numba.jit("f8(f8)", nopython=True)
    def example(x):
      return x + 1.1
    

    We can get access to the object code with

    cres = next(iter(example.overloads.values()))  # first and only type signature (dict views aren't indexable in Python 3)
    elfbytes = cres.library._compiled_object        # private attribute, not a public Numba API
    

    If you print out elfbytes, you'll see that it's an ELF-encoded byte array (a bytes object, not a str, in Python 3). This is the relocatable object code that a compiler would write to a .o file before linking, so it's portable to any machine with the same architecture, the same libraries, etc.
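    Because the whole scheme relies on both machines sharing an architecture, a cheap sanity check on the compute node is to read the fixed part of the ELF header before loading anything. The following is a minimal sketch, assuming a Linux host (so the object really is ELF); elf_summary is a hypothetical helper, not part of Numba or llvmlite, and it maps only a few e_machine codes taken from the ELF specification.

```python
import struct

# A few e_machine values from the ELF specification (not exhaustive).
_MACHINES = {0x03: "x86", 0x28: "ARM", 0x3E: "x86-64", 0xB7: "AArch64"}


def elf_summary(buf: bytes) -> dict:
    """Decode class, byte order, object type and machine from an ELF header."""
    if len(buf) < 20 or buf[:4] != b"\x7fELF":
        raise ValueError("not an ELF object")
    ei_class, ei_data = buf[4], buf[5]  # EI_CLASS and EI_DATA bytes of e_ident
    if ei_class not in (1, 2) or ei_data not in (1, 2):
        raise ValueError("unrecognized ELF class or byte order")
    endian = "<" if ei_data == 1 else ">"
    # e_type and e_machine are 16-bit fields right after the 16-byte e_ident.
    e_type, e_machine = struct.unpack_from(endian + "HH", buf, 16)
    return {
        "bits": 32 if ei_class == 1 else 64,
        "endian": "little" if ei_data == 1 else "big",
        "relocatable": e_type == 1,  # ET_REL: an unlinked object file
        "machine": _MACHINES.get(e_machine, hex(e_machine)),
    }
```

    On the receiving side one could compare elf_summary(elfbytes) against what the accumulator node reported and refuse to load a mismatch, which is far cheaper than a crash inside the JIT.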

    There are several functions inside this bundle, which you can see by dumping the LLVM IR:

    print(cres.library.get_llvm_str())
    

    The one we want is named __main__.example$1.float64 and we can see its type signature in the LLVM IR:

    define i32 @"__main__.example$1.float64"(double* noalias nocapture %retptr, { i8*, i32 }** noalias nocapture readnone %excinfo, i8* noalias nocapture readnone %env, double %arg.x) #0 {
    entry:
      %.14 = fadd double %arg.x, 1.100000e+00
      store double %.14, double* %retptr, align 8
      ret i32 0
    }
    

    Take note for future reference: the first argument is a pointer to a double that gets overwritten with the result, the second and third arguments are pointers that this function never uses (LLVM marks them readnone), and the last argument is the input double. The i32 return value is a status code; 0 means success.

    (The mangled name format differs between Numba versions, so rather than hard-coding it, you can get the function names programmatically with [x.name for x in cres.library._final_module.functions]. The entry point that Numba actually uses is cres.fndesc.mangled_name.)

    We transmit this ELF and function signature from the machine that does all the compiling to the machine that does all the computing.
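    How the bytes travel is up to the application. As one possibility, here is a minimal sketch that assumes the entry-point name and the ELF bytes go together in a single binary message with a length-prefixed header; pack_payload and unpack_payload are hypothetical names, and pickle or any RPC layer would work just as well. The ctypes signature itself is not data here: both sides still have to agree on it in code.

```python
import struct

# Header: two unsigned 32-bit lengths in network byte order (name, object code).
_HEADER = struct.Struct("!II")


def pack_payload(symbol: str, obj: bytes) -> bytes:
    """Frame the entry-point name and the object code into one message."""
    name = symbol.encode("utf-8")
    return _HEADER.pack(len(name), len(obj)) + name + obj


def unpack_payload(msg: bytes):
    """Inverse of pack_payload; returns (symbol, object_bytes)."""
    name_len, obj_len = _HEADER.unpack_from(msg, 0)
    start = _HEADER.size
    name = msg[start:start + name_len].decode("utf-8")
    obj = msg[start + name_len:start + name_len + obj_len]
    if len(obj) != obj_len:
        raise ValueError("truncated payload")
    return name, obj
```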

    Reading it back

    Now on the compute machine, we're going to use llvmlite with no Numba at all (following this page). Initialize it:

    import llvmlite.binding as llvm
    
    llvm.initialize()
    llvm.initialize_native_target()
    llvm.initialize_native_asmprinter()  # yes, even this one
    

    Create an LLVM execution engine:

    target = llvm.Target.from_default_triple()
    target_machine = target.create_target_machine()
    backing_mod = llvm.parse_assembly("")
    engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
    

    And now hijack its caching mechanism to have it load our ELF, named elfbytes:

    def object_compiled_hook(ll_module, buf):
        pass
    
    def object_getbuffer_hook(ll_module):
        return elfbytes
    
    engine.set_object_cache(object_compiled_hook, object_getbuffer_hook)
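    The hooks above hand back elfbytes for every module the engine asks about. If the same engine is later used to compile other modules too, a slightly more careful sketch serves the precompiled object only for the backing module. This assumes, as llvmlite's ExecutionEngine appears to do, that the getbuffer hook receives the same ModuleRef object the engine was created with and that returning None means "not cached, compile from IR"; make_object_cache_hooks is a hypothetical helper.

```python
def make_object_cache_hooks(target_module, object_code: bytes):
    """Return (notify, getbuffer) callbacks for ExecutionEngine.set_object_cache.

    Precompiled object code is served only for `target_module`; for any other
    module the getbuffer hook returns None so the engine compiles it normally.
    """

    def notify(module, buf):
        # Called after LLVM compiles a module itself; nothing to store here.
        pass

    def getbuffer(module):
        return object_code if module is target_module else None

    return notify, getbuffer
```

    Usage would then be engine.set_object_cache(*make_object_cache_hooks(backing_mod, elfbytes)).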
    

    Finalize the engine as though we had just compiled LLVM IR, although in fact we skipped that step. Instead of compiling the empty backing module, the engine loads our ELF, thinking it came from its object cache.

    engine.finalize_object()
    

    We should now find our function in this engine's space. If the following returns 0, something's wrong; it should be a function pointer (a nonzero integer address).

    func_ptr = engine.get_function_address("__main__.example$1.float64")
    

    Now we need to interpret func_ptr as a ctypes function we can call. We have to set up the signature manually.

    import ctypes

    # Output slot that the function writes its result into (the %retptr argument).
    out = ctypes.c_double()

    #                        restype first, then argtypes...
    cfunc = ctypes.CFUNCTYPE(ctypes.c_int32,                   # status code, 0 = OK
                             ctypes.POINTER(ctypes.c_double),  # retptr
                             ctypes.c_void_p,                  # excinfo (unused here)
                             ctypes.c_void_p,                  # env (unused here)
                             ctypes.c_double)(func_ptr)        # x

    And now we can call it, passing NULL (None) for the two unused pointers:

    status = cfunc(ctypes.byref(out), None, None, 3.14)  # status should be 0
    print(out.value)
    # 4.24, which is 3.14 + 1.1. Yay!
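    The ctypes half of this can be tried without any compiled ELF. In the minimal sketch below, a Python callback with the same C signature stands in for the Numba function; its raw address plays the role of func_ptr and is turned back into a callable exactly as above. ENTRY_T and stand_in are illustrative names, and Numba is not involved.

```python
import ctypes

# Same shape as the Numba entry point:
#   int32 fn(double *retptr, void *excinfo, void *env, double x)
ENTRY_T = ctypes.CFUNCTYPE(ctypes.c_int32, ctypes.POINTER(ctypes.c_double),
                           ctypes.c_void_p, ctypes.c_void_p, ctypes.c_double)


def _stand_in(retptr, excinfo, env, x):
    # Mimic the compiled body: store x + 1.1 through retptr, report success.
    retptr[0] = x + 1.1
    return 0


stand_in = ENTRY_T(_stand_in)  # keep a reference so the thunk stays alive
address = ctypes.cast(stand_in, ctypes.c_void_p).value  # plays the role of func_ptr

cfunc = ENTRY_T(address)  # rebuild a callable from the raw integer address
out = ctypes.c_double()
status = cfunc(ctypes.byref(out), None, None, 3.14)
```

    After the call, status is 0 and out.value holds 3.14 + 1.1, mirroring the real compiled function.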
    

    More complications

    If the JITed function has array inputs (after all, you want to do the tight loop over many values in the compiled code, not in Python), Numba generates code that recognizes Numpy arrays. The calling convention for this is quite complex, including pointers-to-pointers to exception objects and all the metadata that accompanies a Numpy array as separate parameters. It does not generate an entry point that you can use with Numpy's ctypes interface.

    However, it does provide a very high-level entry point, which takes a Python *args, **kwds as arguments and parses them internally. Here's how you use that.

    First, find the function whose name starts with "cpython.":

    name = [x.name for x in cres.library._final_module.functions if x.name.startswith("cpython.")][0]
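    The trailing [0] silently takes the first match. Since there should be exactly one such function, a slightly stricter sketch fails loudly on zero or several matches. select_symbol is a hypothetical helper and the names in the test are made up for illustration; in practice the list would be [x.name for x in cres.library._final_module.functions].

```python
def select_symbol(names, prefix):
    """Return the single name in `names` that starts with `prefix`.

    Raises LookupError if there are zero or several matches, instead of
    silently picking one.
    """
    matches = [n for n in names if n.startswith(prefix)]
    if len(matches) != 1:
        raise LookupError(
            f"expected exactly one symbol starting with {prefix!r}, found {matches}")
    return matches[0]
```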
    

    There should be exactly one of them. Then, after serialization and deserialization, get its function pointer using the method described above:

    func_ptr = engine.get_function_address(name)
    

    and cast it to a function type with three PyObject* arguments and one PyObject* return value. (LLVM thinks these are i8*.) Two details matter here. First, ctypes.py_object passes Python objects as their PyObject* pointers and converts the returned PyObject* back into a Python object, so no hand-written PyObject structures or id() tricks are needed. Second, the prototype should come from ctypes.PYFUNCTYPE rather than ctypes.CFUNCTYPE: functions created with CFUNCTYPE release the GIL during the call, but this wrapper calls into the Python C API, which requires the GIL to be held.

    import ctypes

    # PyObject *wrapper(PyObject *closure, PyObject *args, PyObject *kwds)
    cpythonfcn = ctypes.PYFUNCTYPE(ctypes.py_object,            # return value
                                   ctypes.py_object,            # closure
                                   ctypes.py_object,            # args tuple
                                   ctypes.py_object)(func_ptr)  # kwds dict

    The first of these three arguments is a closure (global variables that the function accesses), and I'm going to assume we don't need it: use explicit arguments instead of closures. The positional and keyword arguments go in as a tuple and a dict, as in any CPython call.

    def wrapped(*args, **kwds):
        closure = ()  # placeholder; assumes the function captures nothing
        return cpythonfcn(closure, args, kwds)

    Now the function can be called as

    wrapped(whatever_numpy_arguments, ...)

    just like the original Numba dispatcher function.
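    The PYFUNCTYPE cast can be exercised without a Numba-compiled object by letting the CPython C-API function PyObject_Call stand in for the wrapper: it also takes three PyObject* arguments (a callable, an args tuple, a kwargs dict) and returns a new reference. This only illustrates the ctypes mechanics; unlike Numba's wrapper, PyObject_Call's first argument is the callable itself, not a closure.

```python
import ctypes

# PyObject *PyObject_Call(PyObject *callable, PyObject *args, PyObject *kwargs)
# has the same three-PyObject* shape as Numba's "cpython." wrapper.
address = ctypes.cast(ctypes.pythonapi.PyObject_Call, ctypes.c_void_p).value

# PYFUNCTYPE keeps the GIL held during the call and checks the Python error
# indicator afterwards, raising any pending exception in the caller.
WRAPPER_T = ctypes.PYFUNCTYPE(ctypes.py_object, ctypes.py_object,
                              ctypes.py_object, ctypes.py_object)
call3 = WRAPPER_T(address)

# Equivalent to len([1, 2, 3]), but dispatched through a raw function address.
result = call3(len, ([1, 2, 3],), {})
```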

    Bottom line

    After all that, was it worth it? Doing the end-to-end compilation with Numba, the easy way, takes 50 ms for this simple function. Asking for -O3 instead of the default -O2 makes compilation 40% slower.

    Splicing in a pre-compiled ELF file, however, takes 0.5 ms: a factor of 100 faster. Moreover, compilation times will increase with more complex functions but the splicing-in procedure should always take 0.5 ms for any function.
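    The answer doesn't say how these timings were taken. One way to reproduce such numbers is to wrap each path (a fresh Numba compile, or the engine setup through finalize_object) in a zero-argument callable and take the median of several runs. median_ms is a hypothetical helper; note that timing compilation rather than a cache hit requires applying numba.jit to a new function on every run.

```python
import statistics
import time


def median_ms(fn, repeat=20):
    """Median wall-clock time of fn() in milliseconds over `repeat` runs."""
    samples = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)
```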

    For my application, this is absolutely worth it. It means that I can perform computations on 10 MB at a time and spend most of my time computing (doing real work) rather than compiling (preparing to work). Scale the setup cost up by a factor of 100 and I'd have to perform computations on 1 GB at a time. Since a machine is limited to something on the order of 100 GB and it has to be shared among on the order of 100 processes, I'd be in greater danger of hitting resource limitations, load-balancing issues, etc., because the work units would be too coarse.
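    The scaling argument can be made explicit. If setup takes s seconds and a chunk of c bytes is processed at bandwidth b, setup is at most a fraction f of the total time when s / (s + c/b) <= f, i.e. c >= s * b * (1 - f) / f. The required chunk size grows linearly with the setup time, which is why a 100-fold setup cost means 100-fold larger chunks. A minimal sketch, where min_chunk_bytes is a hypothetical helper and any bandwidth plugged in is an assumption, not a figure from the answer:

```python
def min_chunk_bytes(setup_s, bandwidth_bytes_per_s, max_overhead):
    """Smallest chunk size for which setup is at most `max_overhead` of total time.

    From  setup / (setup + chunk / bandwidth) <= max_overhead  it follows that
    chunk >= setup * bandwidth * (1 - max_overhead) / max_overhead.
    """
    if not 0 < max_overhead < 1:
        raise ValueError("max_overhead must be strictly between 0 and 1")
    return setup_s * bandwidth_bytes_per_s * (1 - max_overhead) / max_overhead
```

    For example, with an assumed 1 GB/s and a 10% overhead budget, 0.5 ms of splicing calls for chunks of about 4.5 MB, while 50 ms of compiling calls for about 450 MB.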

    But for other applications, 50 ms is nothing. It all depends on your application.