The last line of the code below raises the following error:

`jax.errors.ConcretizationTypeError Abstract tracer value encountered where concrete value is expected...`

The problem arises in the `bool` function. It appears to be caused by the `lower` return value from `cho_factor`, which `_cho_solve` (note the underscore) requires to be static.

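A quick way to check this diagnosis is the minimal sketch below. It uses a small hypothetical 3×3 matrix rather than the data from the question. Without `vmap`, `cho_factor` returns `lower` as a plain Python `bool`. `vmap`, however, batches every output, so the same value comes back as an array with one entry per batch element. Inside a vmapped `cho_solve` that array is a traced value, so it can no longer be treated as static.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor

# Hypothetical, trivially positive-definite matrix used only for illustration
a = 2.0 * jnp.eye(3)

# Without vmap: `lower` comes back as a plain Python bool
c, lower = cho_factor(a)
print(type(lower))

# With vmap: every output is batched, so `lower` becomes an array
c_batched, lower_batched = jax.vmap(cho_factor)(jnp.stack([a, a]))
print(type(lower_batched), lower_batched.shape)
```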
I'm new to JAX, so I was hoping that vmap-ing `cho_factor` into `cho_solve` would just work. What have I done wrong here?

```python
import jax

key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))

# Batched versions of matmul, cho_factor and cho_solve
matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)

# Form K K^T for each matrix in the batch
k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)  # this line raises the error above
```
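One possible workaround, assuming the diagnosis above is right, is to vmap a single per-example function that calls both `cho_factor` and `cho_solve`. That way `lower` never crosses the `vmap` boundary and stays a Python `bool`. This is a sketch, not a definitive fix, and it makes several assumptions. `solve_one` is a hypothetical helper name. The two arrays use separate keys from `jax.random.split`. A small jitter `1e-3 * I`, which is not in the original code, is added so that each matrix is comfortably positive definite in float32.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

key = jax.random.PRNGKey(0)
k_key, y_key = jax.random.split(key)  # separate keys for the two arrays
k_y = jax.random.normal(k_key, (100, 10, 10))
y = jax.random.normal(y_key, (100, 10, 1))

# Batched K K^T plus a small (assumed) jitter so each matrix is positive definite
k_y = k_y @ jnp.transpose(k_y, (0, 2, 1)) + 1e-3 * jnp.eye(10)


def solve_one(k, b):
    """Hypothetical helper: solve k x = b for one (10, 10) matrix and (10, 1) RHS."""
    # Here `lower` from cho_factor is still a plain Python bool, not a
    # batched tracer, so cho_solve can treat it as a static argument.
    c_and_lower = cho_factor(k)
    return cho_solve(c_and_lower, b)


# vmap the whole per-example solve over the leading batch axis
result = jax.vmap(solve_one)(k_y, y)
print(result.shape)
```

The key design choice is to vmap over the factor and the solve together, so only array arguments are batched. Another option is to keep the batched `chol` from the question and close over `lower=False` (the `cho_factor` default) inside a function passed to `jax.vmap`. That should also keep `lower` out of the batched values.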