I am trying to JIT-compile my code with Numba using the nopython=True flag, but I get an error for functions that receive a function as an argument:

import numba
import numpy as np      
 
x = np.random.randn(10,10)     
f = lambda x : (x>0)*x

@numba.jit(nopython=True)
def a(x,f): return f(x)**2+x

a(x,f)

The error message received is:

  Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at <stdin> (2)

File "<stdin>", line 2:
<source missing, REPL/exec in use?>

This error may have been caused by the following argument(s):
- argument 1: Cannot determine Numba type of <class 'function'>

Omitting the nopython flag does work:

@numba.jit
def a(x,f): return f(x)**2+x

a(x,f)

Can I somehow annotate the code, or define f differently, to inform Numba of its type?

Uri Cohen
  • As an aside, in python parlance, that is just a *function*. Not a "function handle" – juanpa.arrivillaga Nov 12 '22 at 19:46
  • Anyway, I'm not sure about this, but I think Numba will only handle other Numba functions correctly; at least, that is what the examples seem to imply in the [docs](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#functions-as-arguments) – juanpa.arrivillaga Nov 12 '22 at 20:15
  • I don't think Numba can compile a function that accepts an arbitrary function as an argument. I am not even sure that it can compile one that accepts another jitted function. Numba has a very limited scope, so it tries to fall back to object mode, which nopython does not allow. – Dimitrius Nov 12 '22 at 20:21

1 Answer

You cannot pass a pure-Python function that has not been compiled with Numba to Numba code running in nopython mode. If you really want to do that, you need to use the objmode context manager, which is experimental and inefficient (so it is useful only in a few special cases). The typical solution is simply to compile the function with Numba so that it can be called from nopython mode:

import numba
import numpy as np      
 
x = np.random.randn(10,10)   
  
@numba.njit
def f(x):
    return (x>0)*x

@numba.njit
def a(x,f):
    return f(x)**2+x

a(x,f)
Jérôme Richard