This is a minimal example of a larger problem I am facing. Consider the function below:
```python
import jax
import jax.numpy as jnp

def test(x):
    return jnp.sum(x)
```
I tried to vectorize it with `jax.vmap`:

```python
v_test = jax.vmap(test)
```
My inputs to `test` look like:

```python
x1 = jnp.array([1, 2, 3])
x2 = jnp.array([4, 5, 6, 7])
x3 = jnp.array([8, 9])
x4 = jnp.array([10])
```
and my input to `v_test` is:

```python
x = [x1, x2, x3, x4]
```
If I try:

```python
v_test(x)
```

I get the error below:

```
ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([3, 4, 2, 1],)
```
Is there a way to vectorize `test` over a list of arrays of unequal length?

I could avoid this by padding the arrays to the same length; however, padding is not desirable here.
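One padding-free workaround, specific to reductions like the `jnp.sum` in this minimal example, is a segmented reduction: concatenate all arrays into one flat array, label each element with the index of the array it came from, and reduce per label with `jax.ops.segment_sum`. This is only a sketch under that assumption. It does not vectorize an arbitrary `test`, so it helps only if the real function can be expressed as per-segment operations. The per-array loop is included purely as a reference result.

```python
import jax
import jax.numpy as jnp

x = [jnp.array([1, 2, 3]), jnp.array([4, 5, 6, 7]),
     jnp.array([8, 9]), jnp.array([10])]

# Reference: call the reduction once per array (a Python loop, not vectorized).
per_array = [jnp.sum(a) for a in x]

# Segmented approach: one flat array of all elements, no padding.
flat = jnp.concatenate(x)                      # shape (10,)
lengths = jnp.array([a.shape[0] for a in x])   # [3, 4, 2, 1]

# Segment id of each element: [0, 0, 0, 1, 1, 1, 1, 2, 2, 3].
# total_repeat_length fixes the output size, which also keeps this usable under jit.
segment_ids = jnp.repeat(jnp.arange(len(x)), lengths,
                         total_repeat_length=flat.shape[0])

# A single vectorized reduction that yields one sum per original array.
seg_sums = jax.ops.segment_sum(flat, segment_ids, num_segments=len(x))
print(seg_sums)
```

The design trade-off is that the ragged structure moves into `segment_ids`, so every element is processed in one call without padding. The price is that the computation must be rewritten in terms of segment operations rather than reusing `test` unchanged.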