
I am using the following code to set a particular row of a 2D JAX array to values built from several JAX arrays of different lengths:

```python
import jax.numpy as jnp

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1, 2, 3, 4])
value_2 = jnp.array([1])
value_3 = jnp.array([1, 2])
values = jnp.array([value, value_2, value_3])  # the ValueError below is raised here
zeros_array = zeros_array.at[0].set(values)
```

But I receive the following error:

ValueError: All input arrays must have the same shape.

When I change jnp to np (NumPy), the error disappears. Is there any way to resolve this error? I know one workaround would be to set each of the separate arrays into the 2D array individually, using at[0,1].set(), at[0,2:n].set(), and so on.
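A minimal sketch of that slice-based workaround, assuming the three pieces are meant to sit back-to-back at the start of row 0 (the question does not say where each piece should go). Because `jnp.concatenate` accepts 1D inputs of different lengths, the same result can also be produced with a single `.at[...]` update:

```python
import jax.numpy as jnp

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1, 2, 3, 4])
value_2 = jnp.array([1])
value_3 = jnp.array([1, 2])

# Assumption: pieces are written contiguously into row 0, starting at column 0.
offset = 0
for piece in (value, value_2, value_3):
    zeros_array = zeros_array.at[0, offset:offset + piece.shape[0]].set(piece)
    offset += piece.shape[0]

# Alternative: join the 1D pieces first, then write them with one slice update.
flat = jnp.concatenate([value, value_2, value_3])
single_update = jnp.zeros((3, 8)).at[0, :flat.shape[0]].set(flat)
```

Both versions leave row 0 as `[1, 2, 3, 4, 1, 1, 2, 0]`; the integer pieces are cast to the float dtype of `zeros_array` by the update.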

imk
  • For numpy I see this error "ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (3,) + inhomogeneous part." – joel Jul 02 '23 at 22:47
  • I tried it out on Google Colab and it ran with NumPy. It seems the behavior varies with the NumPy version, as mentioned by @jakevdp – imk Jul 03 '23 at 07:41

1 Answer


What you have in mind is a "ragged array", and no, there is not currently any way to do this in JAX. In older versions of NumPy, this works by returning an array of dtype object, but in newer versions of NumPy this results in an error, because object arrays are generally inconvenient and inefficient to work with (for example, there's no way to efficiently do the equivalent of the index update operation in your last line if the updates are stored in an object array).
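To see the NumPy side of this, here is a small sketch. Whether the first `np.array(rows)` call raises depends on the installed NumPy version (recent releases raise a `ValueError`; older ones fell back to an object array with a deprecation warning), so the code only catches the error rather than assuming it. Asking for `dtype=object` explicitly still succeeds, but the result is just a 1D array of references to separate arrays, so any per-row work becomes a Python-level loop:

```python
import numpy as np

rows = [np.array([1, 2, 3, 4]), np.array([1]), np.array([1, 2])]

# Ragged input without an explicit dtype: an error on recent NumPy versions.
try:
    implicit = np.array(rows)
except ValueError as err:
    implicit = None
    print("ValueError:", err)

# Explicit object dtype: a shape-(3,) array whose elements are separate ndarrays.
ragged = np.array(rows, dtype=object)
print(ragged.shape, ragged.dtype)  # (3,) object

# There is no contiguous numeric buffer, so row-wise work loops in Python.
row_sums = np.array([r.sum() for r in ragged])
```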

Depending on your use-case, there are several workarounds for this you might use in both JAX and NumPy, including storing the rows of your array as a list, or using a padded 2D array representation.
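A rough sketch of both workarounds in JAX. The helper `pad_rows`, the pad width of 8, and the fill value of 0 are illustrative assumptions, not a JAX API. The list form keeps each row as its own 1D array; the padded form stores a regular `(rows, width)` array plus the true row lengths, so ordinary `.at[...]` updates work and a mask tells reductions which entries are real:

```python
import jax.numpy as jnp

rows = [jnp.array([1, 2, 3, 4]), jnp.array([1]), jnp.array([1, 2])]

# Option 1: a plain Python list of 1D arrays; "updating" a row replaces the entry.
rows_list = list(rows)
rows_list[1] = jnp.array([7, 8, 9])


# Option 2: a padded 2D array plus per-row lengths.
def pad_rows(rows, width, fill=0):
    """Hypothetical helper: right-pad each 1D row to `width` with `fill`."""
    padded = jnp.stack(
        [jnp.pad(r, (0, width - r.shape[0]), constant_values=fill) for r in rows]
    )
    lengths = jnp.array([r.shape[0] for r in rows])
    return padded, lengths


padded, lengths = pad_rows(rows, width=8)  # padded.shape == (3, 8)

# The padded form is a regular array, so index updates apply directly.
padded = padded.at[1, :3].set(jnp.array([7, 8, 9]))
lengths = lengths.at[1].set(3)

# Mask of valid (non-padding) entries; reductions must respect it.
mask = jnp.arange(padded.shape[1]) < lengths[:, None]
row_sums = jnp.where(mask, padded, 0).sum(axis=1)
```

The padded layout costs some wasted storage when row lengths vary widely, but it keeps everything in one array with a fixed shape, which is what JAX transformations such as `jit` and `vmap` expect.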

I'll note also that the JAX team is exploring native support for ragged arrays (see e.g. https://github.com/google/jax/pull/16541) but it's still fairly far from being generally useful.

jakevdp