This is a minimal example of the real larger problem I am facing. Consider the function below:

import jax
import jax.numpy as jnp

def test(x):
    return jnp.sum(x)

I tried to vectorize it with:

v_test = jax.vmap(test)

My inputs to test look like:

x1 = jnp.array([1,2,3])
x2 = jnp.array([4,5,6,7])
x3 = jnp.array([8,9])
x4 = jnp.array([10])

and my input to v_test is:

x = [x1, x2, x3, x4]

If I try:

v_test(x)

I get the error below:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
([3, 4, 2, 1],)

Is there a way to vectorize test over a list of arrays of unequal length? I could avoid the error by padding the arrays to the same length, but padding is not desired.

MOON

1 Answer

JAX does not support ragged arrays (i.e. arrays in which each row has a different number of elements), so there is currently no way to use vmap for this kind of data. Your best bet is probably to use a Python for loop:

y = [test(xi) for xi in x]

Alternatively, you might be able to express the operation you have in mind in terms of segment_sum or similar operations. For example:

segments = jnp.concatenate([i * jnp.ones_like(xi) for i, xi in enumerate(x)])
result = jax.ops.segment_sum(jnp.concatenate(x), segments)
print(result)
# [ 6 22 17 10]
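If the reduction needs to run under jax.jit, the number of segments generally has to be known at trace time. Here is a minimal sketch of one way to do that, assuming the same x as in the question. The helper name ragged_sum is hypothetical, and passing num_segments as a static argument is just one reasonable choice:

```python
from functools import partial

import jax
import jax.numpy as jnp

x = [
    jnp.array([1, 2, 3]),
    jnp.array([4, 5, 6, 7]),
    jnp.array([8, 9]),
    jnp.array([10]),
]

# Flatten the ragged list into one 1D array, plus one segment id per element.
lengths = jnp.array([xi.shape[0] for xi in x])
data = jnp.concatenate(x)
segment_ids = jnp.repeat(jnp.arange(len(x)), lengths)


# ragged_sum is a hypothetical helper name. num_segments is marked static
# so that the output shape is known when the function is traced.
@partial(jax.jit, static_argnames="num_segments")
def ragged_sum(data, segment_ids, num_segments):
    return jax.ops.segment_sum(data, segment_ids, num_segments=num_segments)


result = ragged_sum(data, segment_ids, num_segments=len(x))
print(result)
```

Note that the segment ids are built outside the jitted function, because the ragged lengths are only available as concrete Python values there.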

Another possibility is to pad the input arrays so that they can fit into a standard, non-ragged 2D array.

jakevdp
  • Does the HLO representation of comprehension have the same limitations as a native `for` loop (with back-prop in mind)? Also, mapping the second index is not required in the straight reduction example above, but if it was one could use `map_j = jnp.concatenate([jnp.arange(xi.size) for xi in x])` – DavidJ Dec 15 '22 at 10:53
  • I don't understand your question. Perhaps create a new post with more details? Comment threads aren't a great medium for asking additional questions. – jakevdp Dec 15 '22 at 14:02