
I serialized a jitted Numba function to a byte array and now want to deserialize and call it. This works fine for primitive data types with llvm_cfunc_wrapper_name:

import numba, ctypes
import llvmlite.binding as llvm

@numba.njit("f8(f8)")
def foo(x):
    return x + 0.5

# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cfunc_name = foo.overloads[sig].fndesc.llvm_cfunc_wrapper_name
function_bytes = lib._get_compiled_object()

# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cfunc_name)

func = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)(func_ptr)

print(func(0.25))

But I want to call functions that take NumPy arrays as arguments. For that there is llvm_cpython_wrapper_name, whose wrapper uses the PyCFunctionWithKeywords calling convention, but unfortunately my best guess segfaults:

import numba, ctypes
import llvmlite.binding as llvm
import numpy as np

@numba.njit("f8[:](f8[:])")
def foo(x):
    return x + 0.5

# serialize function to byte array
sig = foo.signatures[0]
lib = foo.overloads[sig].library
cpython_name = foo.overloads[sig].fndesc.llvm_cpython_wrapper_name
function_bytes = lib._get_compiled_object()

# deserialize function_bytes to func
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine()
backing_mod = llvm.parse_assembly("")
engine = llvm.create_mcjit_compiler(backing_mod, target_machine)
engine.add_object_file(llvm.ObjectFileRef.from_data(function_bytes))
func_ptr = engine.get_function_address(cpython_name)

def func(*args, **kwargs):
    py_obj_ptr = ctypes.POINTER(ctypes.py_object)
    return ctypes.CFUNCTYPE(py_obj_ptr, py_obj_ptr, py_obj_ptr, py_obj_ptr)(func_ptr)(
        ctypes.cast(id(None), py_obj_ptr),
        ctypes.cast(id(args), py_obj_ptr),
        ctypes.cast(id(kwargs), py_obj_ptr))

# segfaults here
print(func(np.ones(3)))
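For reference, a PyCFunctionWithKeywords has the C signature PyObject *(PyObject *self, PyObject *args, PyObject *kwargs). The sketch below shows one way to express that prototype in ctypes. It uses ctypes.PYFUNCTYPE rather than CFUNCTYPE because, per the ctypes documentation, functions created from CFUNCTYPE prototypes release the GIL during the call while PYFUNCTYPE ones do not, and code that manipulates Python objects needs the GIL. To keep it runnable without Numba, the raw address comes from a hypothetical Python stand-in (_standin) exposed through the same prototype; with the deserialized library you would pass func_ptr instead. Whether this is enough for the Numba wrapper is an untested assumption: the cpython wrapper may also depend on per-process state that Numba sets up after compilation (for example its environment object), so this alone may not prevent the crash.

```python
import ctypes

# C signature of PyCFunctionWithKeywords:
#   PyObject *f(PyObject *self, PyObject *args, PyObject *kwargs)
# PYFUNCTYPE keeps the GIL held during the call; CFUNCTYPE would release it.
PyCFunctionWithKeywords = ctypes.PYFUNCTYPE(
    ctypes.py_object,  # return value (a new reference)
    ctypes.py_object,  # self / closure
    ctypes.py_object,  # positional arguments, as a tuple
    ctypes.py_object,  # keyword arguments, as a dict
)


def call_with_keywords(func_ptr, *args, **kwargs):
    """Call a raw PyCFunctionWithKeywords address like a Python function."""
    fn = PyCFunctionWithKeywords(func_ptr)
    # ctypes passes each argument as the PyObject* of the Python object.
    return fn(None, args, kwargs)


# Hypothetical stand-in for the compiled wrapper: a Python callback exposed
# through the same prototype, so this sketch runs without Numba.
def _standin(self, args, kwargs):
    return sum(args) + kwargs.get("offset", 0)


_standin_c = PyCFunctionWithKeywords(_standin)  # keep the callback alive
ptr = ctypes.cast(_standin_c, ctypes.c_void_p).value

print(call_with_keywords(ptr, 1, 2, offset=0.5))  # 3.5
```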

Here are some links to Numba source code (unfortunately very hard to follow), which might be helpful to figure this out.

BlueSky
  • Why do you want to serialize a Numba-compiled function in the first place? This is very unsafe unless you really know what you are doing. The compiled code is not portable from one machine to another, and certainly not portable from one process to another either. I am not even sure you could run it from a different context, because the addresses in the compiled code could be absolute ones. Such things will cause segmentation faults. – Jérôme Richard Feb 20 '22 at 13:39
  • The reason is that Numba is notoriously slow at jitting, even with caching enabled. Ahead-of-time compilation should solve this in theory, but requires Visual Studio on Windows, which is not desirable. The cfunc wrapper is sufficiently portable in practice, but I do not know enough about Numba internals to judge whether that is also true with the cpython wrapper. – BlueSky Feb 21 '22 at 08:26
