This question solved my problem of using `vmap` on `cho_solve`. Is it possible to `vectorize` `cho_solve` as well, or does the definition of `cho_solve` preclude it from being vectorized? `vectorize` seems to need all of its arguments to be arrays, whereas `cho_solve` takes a tuple as its first argument.
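```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
# Batch of 3 x 5 random 10x10 matrices; y is a matching batch of identities.
k_y = jax.random.normal(subkey, (3, 5, 10, 10))
y = jnp.broadcast_to(jnp.eye(10), k_y.shape)

# Batched wrappers built with jnp.vectorize.
matmul = jnp.vectorize(jnp.matmul, signature='(a,b),(b,c)->(a,c)')
# Argument 1 (`lower`) is excluded, so it is passed through unchanged.
cholesky = jnp.vectorize(jsp.linalg.cholesky, excluded={1}, signature='(d,d)->(d,d)')
cho_solve = jnp.vectorize(jsp.linalg.cho_solve, signature='(d,d),(d,d)->(d,d)')  # what to put here?

# Make each matrix symmetric positive definite, then factor it (lower=True).
k_y = matmul(k_y, jnp.moveaxis(k_y, -1, -2))
chol = cholesky(k_y, True)
# Fails: the first argument is the tuple (chol, True), not an array.
result = cho_solve((chol, True), y)
```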
Running this raises:

```
ValueError: All input arrays must have the same shape.
```
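A minimal sketch of one possible way to fill in the "what to put here?" line, assuming the tuple is the only obstacle: a small wrapper takes the factor and the `lower` flag as separate positional arguments, excludes `lower` with `excluded={1}` exactly as the `cholesky` line does, and re-packs the tuple before calling `jax.scipy.linalg.cho_solve`. The names `_cho_solve_split`, `_cho_solve_vec` and the batched `cho_solve` are hypothetical helpers, not JAX APIs. The right-hand-side signature `(d,e)` is a small generalization of the question's `(d,d)`, and the `+ 10 * I` shift is only there to keep the example well conditioned.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

def _cho_solve_split(c, lower, b):
    # Hypothetical helper: same as jsp.linalg.cho_solve, but with the
    # (c, lower) tuple split into two positional arguments.
    return jsp.linalg.cho_solve((c, lower), b)

# Argument 1 (`lower`) is excluded, as in the question's `cholesky` wrapper,
# so the signature only describes the array arguments c and b.
_cho_solve_vec = jnp.vectorize(_cho_solve_split, excluded={1},
                               signature='(d,d),(d,e)->(d,e)')

def cho_solve(c_and_lower, b):
    """Batched cho_solve that keeps the ((c, lower), b) calling convention."""
    c, lower = c_and_lower
    return _cho_solve_vec(c, lower, b)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3, 5, 10, 10))
# SPD batch; the +10*I shift only keeps this example well conditioned.
k_y = a @ jnp.swapaxes(a, -1, -2) + 10.0 * jnp.eye(10)
y = jnp.broadcast_to(jnp.eye(10), k_y.shape)

chol = jnp.linalg.cholesky(k_y)  # batched lower-triangular factors
result = cho_solve((chol, True), y)
```

Because `b` also goes through `vectorize`, an unbatched `jnp.eye(10)` should broadcast against the `(3, 5)` batch of factors just like the pre-broadcast `y`.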
My use case is that I have an unspecified number of "batch" dimensions that I want to `vmap` over, and `vectorize` handles the automatic broadcasting beautifully. I could once again write my own `cho_solve` using `solve_triangular`, but this seems like a waste. Is it possible for `vectorize` to have an interface similar to `vmap`, which can take nested signatures?
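For comparison, a sketch of the hand-written alternative the question mentions: `cho_solve` rebuilt from two `jax.scipy.linalg.solve_triangular` calls (forward then back substitution) and batched with a `jnp.vectorize` signature. It assumes a lower-triangular factor `c` with `K = c @ c.T`; `_cho_solve_tri` and `cho_solve_tri` are hypothetical names, and the `+ 10 * I` shift is again only for conditioning. It sidesteps the tuple argument entirely, at the cost of duplicating what `cho_solve` already does.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

def _cho_solve_tri(c, b):
    """Solve (c @ c.T) x = b for one lower-triangular factor c (hypothetical helper)."""
    # Forward substitution: c @ z = b.
    z = jsp.linalg.solve_triangular(c, b, lower=True)
    # Back substitution: c.T @ x = z ('T' solves with the transpose of c).
    return jsp.linalg.solve_triangular(c, z, lower=True, trans='T')

# Broadcasts over any number of leading batch dimensions.
cho_solve_tri = jnp.vectorize(_cho_solve_tri, signature='(d,d),(d,e)->(d,e)')

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3, 5, 10, 10))
# Well-conditioned SPD batch (the +10*I shift is an illustrative choice).
k_y = a @ jnp.swapaxes(a, -1, -2) + 10.0 * jnp.eye(10)
chol = jnp.linalg.cholesky(k_y)  # batched lower-triangular factors
result = cho_solve_tri(chol, jnp.eye(10))
```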