
I am using Numba's `@njit` decorator to compile a function that gets used in parallel processes, but it is slower than I expected. Processes that should differ by an order of magnitude in execution time take about the same time, which suggests a lot of compilation overhead.

I have a function

```python
from numba import njit

@njit
def foo(ar):
    # (do something)
    return ar
```

and a normal Python function

```python
def bar(x):
    # (do something)
    return foo(x)
```

which gets called in parallel processes like

```python
import concurrent.futures

if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor(max_workers=maxWorkers) as executor:
        results = executor.map(bar, args)
```

Here `args` is a long list of arguments. Does this mean that `foo()` gets compiled separately within each process? That would explain the extra overhead. Is there a good solution for this? I could call `foo()` once on one of the arguments before spawning the processes, forcing it to compile ahead of time. Is there a better way?

Abelaer


With the `spawn` start method (the default on Windows and macOS), multiprocessing causes the spawned processes to execute the code that is not in the main section (i.e. outside `if __name__ == "__main__":`). Each process also has its own memory, so a Numba function compiled in one process is not available in the others. As a result, every worker compiles `foo` again on its first call.

Caching can be used to compile the function once and store it. Later compilations are then much faster, because the code is loaded from the cache. This only works if the function context is the same (e.g. parameter types, dependence on global variables, compilation flags, etc.). This feature is enabled with `@nb.njit(cache=True)`. For more information, please read the section on caching in the Numba documentation. In your case, once one process has compiled and cached the function, the other processes load it from the cache instead of recompiling it.

Note that it is often better to use Numba's multithreading (`@njit(parallel=True)` with `prange`) instead of multiprocessing, since spawning processes is more expensive (both in time and memory usage). That being said, only a limited set of functions can be called from a multithreaded Numba context (mainly NumPy and Numba functions).

Jérôme Richard
  • Thanks! Adding `@njit(cache=True)` solved my problem; the timing now seems much more reasonable (more than an order of magnitude faster). I should have made clear that `bar()` was already called from a main section; I have changed this in the question. I used processes instead of threads since `bar()` is compute-heavy rather than I/O-heavy. As far as I know, that is the general rule of thumb, correct? – Abelaer Jun 01 '22 at 09:31
  • OK. No, the rule of using processes for compute-bound tasks and threads for I/O-bound tasks only applies to CPython threads, not Numba threads. It comes from the GIL in CPython, which is not held during parallel Numba operations. Numba threads can be created faster than CPython processes and work in shared memory, instead of relying on slow inter-process communication. – Jérôme Richard Jun 02 '22 at 11:25