I am using the following code to set a particular row of a 2D JAX array to a value assembled from several JAX arrays:
```python
import jax.numpy as jnp

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1, 2, 3, 4])
value_2 = jnp.array([1])
value_3 = jnp.array([1, 2])
# The three pieces have shapes (4,), (1,) and (2,)
values = jnp.array([value, value_2, value_3])
zeros_array = zeros_array.at[0].set(values)
```
However, I receive the following error:

```
ValueError: All input arrays must have the same shape.
```
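The error appears to come from `jnp.array(...)` itself rather than from `.at[0].set(...)`: `jnp.array` needs a rectangular (non-ragged) input, and the three pieces have lengths 4, 1 and 2. A minimal sketch, assuming a recent JAX release, that isolates that call and contrasts it with pieces of equal length:

```python
import jax.numpy as jnp

value = jnp.array([1, 2, 3, 4])
value_2 = jnp.array([1])
value_3 = jnp.array([1, 2])

# jnp.array cannot build one rectangular array out of pieces
# of shape (4,), (1,) and (2,), so it raises ValueError.
try:
    values = jnp.array([value, value_2, value_3])
    ragged_failed = False
except ValueError as err:
    ragged_failed = True
    print("ValueError:", err)

# Pieces of equal length stack without complaint into a 2D array.
stacked = jnp.array([value, value])
print(stacked.shape)  # (2, 4)
```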
If I change `jnp` to `np` (NumPy), the error disappears. Is there any way to resolve this error in JAX? I know one workaround would be to set each of the separate arrays into the 2D array individually, using `at[0,1].set()`, `at[0,2:n].set()`.
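A minimal sketch of one possible approach, assuming the intent is to place the pieces back to back in row 0 starting at column 0 (the exact target columns are an assumption here). `jnp.concatenate` joins 1D arrays of different lengths into a single 1D array, which can then be written with one `.at[...].set(...)` call; since JAX arrays are immutable, `.at[...].set(...)` returns a new array. The piece-by-piece loop mirrors the workaround from the question, with a hypothetical `start` counter tracking each piece's first column.

```python
import jax.numpy as jnp

zeros_array = jnp.zeros((3, 8))
value = jnp.array([1, 2, 3, 4])
value_2 = jnp.array([1])
value_3 = jnp.array([1, 2])

# Assumption: the pieces sit back to back in row 0, starting at column 0.
# jnp.concatenate accepts 1D pieces of different lengths.
row_values = jnp.concatenate([value, value_2, value_3])  # shape (7,)
n = row_values.shape[0]

# One functional update of the first n columns of row 0.
# The integer values are cast to zeros_array's float dtype.
result = zeros_array.at[0, :n].set(row_values)

# Piece-by-piece alternative: one update per piece, with a
# hypothetical `start` counter holding each piece's first column.
piecewise = zeros_array
start = 0
for piece in (value, value_2, value_3):
    piecewise = piecewise.at[0, start:start + piece.shape[0]].set(piece)
    start += piece.shape[0]

# Row 0 now holds 1, 2, 3, 4, 1, 1, 2 followed by one remaining zero.
print(result[0])
```

The concatenated version needs a single update. If the pieces have to land at non-adjacent columns, the loop is easier to adapt, because each piece's start column is explicit.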