I serialized a jitted Numba function to a byte array and now want to deserialize and call it. This works fine for primitive data types with llvm_cfunc_wrapper_name
:
import numba, ctypes
import llvmlite.binding as llvm
@numba.njit("f8(f8)")
def foo(x):
return x + 0.5
# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cfunc_name = foo.overloads[sig].fndesc.llvm_cfunc_wrapper_name
function_bytes = lib._get_compiled_object()
# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cfunc_name)
func = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(func_ptr)
print(func(0.25))
But I want to call functions with NumPy arguments. There is a llvm_cpython_wrapper_name
for that which uses PyCFunctionWithKeywords
, but unfortunately my best guess segfaults:
import numba, ctypes
import llvmlite.binding as llvm
import numpy as np
@numba.njit("f8[:](f8[:])")
def foo(x):
return x + 0.5
# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cpython_name = foo.overloads[sig].fndesc.llvm_cpython_wrapper_name
function_bytes = lib._get_compiled_object()
# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cpython_name)
def func(*args, **kwargs):
py_obj_ptr = ctypes.POINTER(ctypes.py_object)
return ctypes.CFUNCTYPE(py_obj_ptr, py_obj_ptr, py_obj_ptr, py_obj_ptr)(func_ptr)(
ctypes.cast(id(None), py_obj_ptr),
ctypes.cast(id(args), py_obj_ptr),
ctypes.cast(id(kwargs), py_obj_ptr))
# segfaults here
print(func(np.ones(3)))
Here are some links to Numba source code (unfortunately very hard to follow), which might be helpful to figure this out.