When trying to use Numba to JIT-compile my code with the nopython=True flag, I get an error for functions that receive a function as an argument:
import numba
import numpy as np
x = np.random.randn(10,10)
f = lambda x : (x>0)*x
@numba.jit(nopython=True)
def a(x,f): return f(x)**2+x
a(x,f)
The error message received is:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at <stdin> (2)
File "<stdin>", line 2:
<source missing, REPL/exec in use?>
This error may have been caused by the following argument(s):
- argument 1: Cannot determine Numba type of <class 'function'>
Omitting the nopython flag does work:
@numba.jit
def a(x,f): return f(x)**2+x
a(x,f)
Can I somehow annotate the code, or define f differently, so that numba knows its type?