
I have a function that looks like this:

        import jax
        import jax.numpy as jnp

        @jax.jit
        def f(R):
            tr = jnp.trace(R)

            r00 = R[0, 0]
            r01 = R[0, 1]
            r02 = R[0, 2]
            r10 = R[1, 0]
            r11 = R[1, 1]
            r12 = R[1, 2]
            r20 = R[2, 0]
            r21 = R[2, 1]
            r22 = R[2, 2]

            condw = tr > 0
            condx = (r00 > r11) and (r00 > r22)
            condy = (r11 > r22)
            # ... do some more things based on the conditions

where R is a 3x3 DeviceArray. When I try JIT-ing this function as shown above, I get the following error:

        File "/path/to/my/file", line 90, in f
            condx = (r00 > r11) and (r00 > r22)
          File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 544, in __bool__
            def __bool__(self): return self.aval._bool(self)
          File "/Users/me/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/core.py", line 989, in error
            raise ConcretizationTypeError(arg, fname_context)
        jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
        The problem arose with the `bool` function.
        While tracing the function f at /path/to/my/file:75 for jit, this concrete value was not available in Python because it depends on the value of the argument 'R'.

I'm not sure what it is about computing this boolean value that prevents the function from being JIT-compiled:

        condx = (r00 > r11) and (r00 > r22)
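A smaller sketch reproduces the same failure, assuming a JAX version that exposes `jax.errors.ConcretizationTypeError` (the function `g` and its input are made up for illustration). Python's `and` calls `bool()` on its left operand to decide whether to short-circuit; called eagerly that operand is a concrete array, but under `jax.jit` it is an abstract tracer with no value yet:

```python
import jax
import jax.numpy as jnp


def g(x):
    # `and` short-circuits, so Python calls bool() on the left operand.
    return (x[0] > x[1]) and (x[0] > x[2])


x = jnp.array([3.0, 1.0, 2.0])

# Eager call: the comparisons produce concrete arrays, so bool() succeeds.
eager_result = bool(g(x))

# Jitted call: `x[0] > x[1]` is a tracer during tracing, so bool() raises.
try:
    jax.jit(g)(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print(eager_result, raised)  # True True
```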

Any hints would be much appreciated - thanks!

Carpetfizz

1 Answer


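From #3761, use the bitwise operators (`&`, `|`, `~`) instead of the logical operators (`and`, `or`, `not`). Python's `and` calls `bool()` on its left operand to decide whether to short-circuit, which requires a concrete value; under `jax.jit`, `r00 > r11` is an abstract tracer, hence the ConcretizationTypeError. `&` on JAX arrays is evaluated elementwise by JAX (for booleans it is a logical AND) and returns another traced array, so it works under jit. Keep the parentheses, because `&` binds more tightly than `>`.

This works:

        condx = (r00 > r11) & (r00 > r22)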
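A runnable sketch of the fixed function, for reference. The original elides what happens after the conditions are computed, so the branch-index selection at the end is a hypothetical stand-in; it only illustrates that traced booleans are consumed with `jnp.where` rather than a Python `if`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(R):
    tr = jnp.trace(R)
    r00, r11, r22 = R[0, 0], R[1, 1], R[2, 2]

    condw = tr > 0
    # `&` keeps the result a traced boolean instead of forcing bool().
    # Parentheses are required: `&` binds more tightly than `>`.
    condx = (r00 > r11) & (r00 > r22)
    condy = r11 > r22

    # Hypothetical downstream use: choose a branch index without `if`.
    # 0 -> condw, 1 -> condx, 2 -> condy, 3 -> none of them.
    return jnp.where(condw, 0, jnp.where(condx, 1, jnp.where(condy, 2, 3)))


R = jnp.diag(jnp.array([2.0, -1.0, -3.0]))  # trace < 0, r00 is largest
print(int(f(R)))  # 1
```

If the branches compute genuinely different (and expensive) values, `jax.lax.cond` or `jax.lax.switch` are the traced alternatives to `jnp.where`, which evaluates every branch.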
Carpetfizz