
This question solved my problem of using vmap on cho_solve. Is it possible to vectorize cho_solve, or does the definition of cho_solve preclude it from being vectorized? vectorize seems to require that all arguments be arrays, whereas cho_solve takes a tuple as its first argument.

import jax
import jax.numpy as jnp
import jax.scipy as jsp

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

k_y = jax.random.normal(subkey, (3, 5, 10, 10))
y = jnp.broadcast_to(jnp.eye(10), k_y.shape)

matmul = jnp.vectorize(jnp.matmul, signature='(a,b),(b,c)->(a,c)')
cholesky = jnp.vectorize(jsp.linalg.cholesky, excluded={1}, signature='(d,d)->(d,d)')
cho_solve = jnp.vectorize(jsp.linalg.cho_solve, signature='(d,d),(d,d)->(d,d)')  # what to put here?

k_y = matmul(k_y, jnp.moveaxis(k_y, -1, -2))
chol = cholesky(k_y, True)
result = cho_solve((chol, True), y)

ValueError: All input arrays must have the same shape.

My use case is that I have an unspecified number of "batch" dimensions that I want to vmap over, and vectorize handles the automatic broadcasting beautifully. I could once again write my own cho_solve using solve_triangular, but this seems like a waste. Is it possible for vectorize to have an interface similar to vmap's, which can take nested structures?

logan

2 Answers


Here's the solve_triangular solution I managed:

solve_tri = jnp.vectorize(
  jsp.linalg.solve_triangular, excluded={2, 3}, signature='(d,d),(d,d)->(d,d)')
chol = cholesky(k_y, True)
# Forward substitution (solve L z = y), then back substitution (solve L^T x = z):
result2 = solve_tri(chol, solve_tri(chol, y, 0, True), 1, True)
result3 = jnp.array([
  jsp.linalg.cho_solve((chol[a, b], True), y[a, b])
  for a in range(k_y.shape[0])
  for b in range(k_y.shape[1])
]).reshape(k_y.shape)
print(jnp.allclose(result2, result3))

True

logan

I don't believe you can use vectorize directly with cho_solve. The vectorize API requires that the function take arrays as arguments, while cho_solve takes a tuple as its first argument. The only way to use vectorize with this function is to wrap it in one with a different API. For example:

cho_solve = jnp.vectorize(
    lambda chol, flag, y: jsp.linalg.cho_solve((chol, flag), y),
    excluded={1}, signature='(d,d),(d,d)->(d,d)')
result = cho_solve(chol, True, y)
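
As a quick sanity check (a sketch, not from the original post: the random SPD matrices and the `jnp.linalg.solve` reference below are illustrative, chosen to mimic the question's shapes), the wrapped function can be compared against a plain dense solve:

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3, 5, 10, 10))
# Batched symmetric positive-definite matrices (well conditioned via the +10*I shift).
k_y = a @ jnp.swapaxes(a, -1, -2) + 10 * jnp.eye(10)
y = jnp.broadcast_to(jnp.eye(10), k_y.shape)

# Wrap cho_solve so that vectorize sees only array arguments.
cho_solve = jnp.vectorize(
    lambda chol, flag, y: jsp.linalg.cho_solve((chol, flag), y),
    excluded={1}, signature='(d,d),(d,d)->(d,d)')

chol = jnp.linalg.cholesky(k_y)       # lower-triangular factors, batched
result = cho_solve(chol, True, y)     # K^-1 @ y for every batch element
reference = jnp.linalg.solve(k_y, y)  # dense reference solve
print(jnp.allclose(result, reference, atol=1e-4))
```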
jakevdp
  • Would that give the same benefits of the vmapping under the hood, or is that essentially doing the loop from my answer? – logan Jun 30 '23 at 11:56
  • `jnp.vectorize` is built on `vmap`, so it would basically be the same in the end; for some discussion see https://stackoverflow.com/a/69118863/2937831. – jakevdp Jun 30 '23 at 12:01
  • Fantastic, I'll give that a shot. I thought in this case it might just be syntactic sugar like numpy's vectorize is, but it seems not! Thanks for a great library, I'm loving it :) – logan Jun 30 '23 at 12:07
  • `jnp.vectorize` has exactly the same API as numpy's `vectorize`; that's where it gets its restriction on function signatures. For what it's worth, `vmap` would be able to handle `cho_solve`'s normal argument structure directly. – jakevdp Jun 30 '23 at 12:09
  • But `vmap` isn't able to handle the arbitrary batch dimensions, right, unless I'm missing something? In that SO you linked you had to explicitly `vmap` twice, which I thought `vectorize` was built to do under the hood? Is there a performance hit from using `vectorize` over multiple `vmap`s? I want to write the function once, and then call it with multiple different batch sizes, without having to explicitly specify all the `vmap` combinations, is `vectorize` the right tool for that? – logan Jun 30 '23 at 14:11
  • 1
    Sorry to be unclear. You’re correct that vmap and vectorize have different semantics. There is no performance issue with vectorize compared to multiple vmaps, because it’s implemented via multiple vmaps. – jakevdp Jun 30 '23 at 18:05
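
To illustrate the point about `vmap` handling `cho_solve`'s argument structure directly: unlike `vectorize`, `vmap` accepts an `in_axes` specification that mirrors nested (pytree) arguments, so the tuple can be mapped over without any wrapper. This is a sketch for a single known batch dimension (the shapes below are illustrative, not from the thread):

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (7, 4, 4))
# A (7, 4, 4) batch of well-conditioned SPD matrices.
k = a @ jnp.swapaxes(a, -1, -2) + 4 * jnp.eye(4)
y = jnp.broadcast_to(jnp.eye(4), k.shape)

chol = jnp.linalg.cholesky(k)

# in_axes mirrors the argument structure: map axis 0 of the factor inside
# the (c, lower) tuple, leave the boolean flag unmapped, map axis 0 of y.
batched_cho_solve = jax.vmap(jsp.linalg.cho_solve, in_axes=((0, None), 0))

result = batched_cho_solve((chol, True), y)
print(jnp.allclose(result, jnp.linalg.solve(k, y), atol=1e-4))
```

Unlike `vectorize`, this handles exactly one batch axis per `vmap`; for an arbitrary number of batch dimensions you would still need to stack `vmap`s (which is what `jnp.vectorize` does internally).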