I have a function that looks like this:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(R):
    tr = jnp.trace(R)
    r00 = R[0, 0]
    r01 = R[0, 1]
    r02 = R[0, 2]
    r10 = R[1, 0]
    r11 = R[1, 1]
    r12 = R[1, 2]
    r20 = R[2, 0]
    r21 = R[2, 1]
    r22 = R[2, 2]
    condw = tr > 0
    condx = (r00 > r11) and (r00 > r22)
    condy = (r11 > r22)
    # ... do some more things based on the conditions
```
where `R` is a 3x3 `DeviceArray`. When I try to JIT this function as shown above, I get the following error:
```
  File "/path/to/my/file", line 90, in f
    condx = (r00 > r11) and (r00 > r22)
  File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 544, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 989, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function f at /path/to/my/file:75 for jit, this concrete value was not available in Python because it depends on the value of the argument 'R'.
```
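A minimal reproduction, separate from the original `f` (the names `python_and_check` and `fails_under_jit` are hypothetical): the same comparison succeeds on concrete arrays but fails under `jax.jit`, because Python's `and` calls `bool()` on its left operand, and under tracing that operand is an abstract tracer with no concrete value. This sketch assumes `jax.errors.ConcretizationTypeError` is available; newer JAX releases raise the more specific `TracerBoolConversionError`, which subclasses it, so the `except` clause catches either.

```python
import jax
import jax.numpy as jnp


def python_and_check(R):
    # Python's `and` calls bool() on the left-hand comparison to decide
    # whether to evaluate the right-hand side.
    return (R[0, 0] > R[1, 1]) and (R[0, 0] > R[2, 2])


def fails_under_jit(R):
    """Return True if jitting `python_and_check` raises a concretization error."""
    try:
        jax.jit(python_and_check)(R)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


R = jnp.diag(jnp.array([3.0, 2.0, 1.0]))

# Without jit, R holds concrete values, so bool() succeeds.
print(bool(python_and_check(R)))
# With jit, R[0, 0] > R[1, 1] is a tracer, so bool() raises.
print(fails_under_jit(R))
```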
I'm not sure what it is about computing this boolean value that prevents the function from being JIT-ted:

```python
condx = (r00 > r11) and (r00 > r22)
```
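For comparison, a minimal sketch of a trace-friendly version, assuming the rest of `f` only needs the three conditions as traced booleans: `jnp.logical_and` (or the `&` operator) combines the two comparisons elementwise without ever calling `bool()`. The `pick_branch` helper and its branch indices 0 to 3 are hypothetical stand-ins for the elided "do some more things" step; they illustrate that any downstream choice has to use `jnp.where` (or `jax.lax.cond` / `jax.lax.switch`) rather than a Python `if`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def conditions(R):
    tr = jnp.trace(R)
    r00, r11, r22 = R[0, 0], R[1, 1], R[2, 2]
    condw = tr > 0
    # jnp.logical_and combines traced booleans without calling bool().
    condx = jnp.logical_and(r00 > r11, r00 > r22)
    condy = r11 > r22
    return condw, condx, condy


@jax.jit
def pick_branch(R):
    # Hypothetical placeholder: index of the first true condition in the
    # order condw, condx, condy, or 3 if none holds. Nested jnp.where
    # replaces an if/elif/else chain and stays traceable under jit.
    condw, condx, condy = conditions(R)
    return jnp.where(condw, 0, jnp.where(condx, 1, jnp.where(condy, 2, 3)))


print(pick_branch(jnp.eye(3)))
print(pick_branch(jnp.diag(jnp.array([1.0, -1.0, -1.0]))))
```

The returned index could then feed something like `jax.lax.switch(index, branches, R)` if each case needs genuinely different computation.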
Any hints would be much appreciated - thanks!