
What is the proper way to update multiple indexes of a 2D (or multi-dimensional) JAX array at once?

This is a follow-up to my previous question on batch updates for a 1D JAX array, with the goal of avoiding creating millions of arrays during training.

I have tried:

import jax.numpy as jnp

x = jnp.zeros((3, 3))

# Update 1 index at a time
x = x.at[2, 2].set(1)  # or x = x.at[(2, 2)].set(1)
print(x)
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]
# Nice, it works.
# But how about 2 indexes at the same time?
x = jnp.zeros((3, 3))
x = x.at[(1, 0), (0, 1)].set([1, 3])
print(x)
[[0. 3. 0.]
 [1. 0. 0.]
 [0. 0. 0.]]

It works again, but when I try to update three or more indexes:
x = x.at[(1, 0), (0, 1), (1, 1)].set([1, 3, 6])
print(x)
IndexError: Too many indices for array: 3 non-None/Ellipsis indices for dim 2.

I have spent some time browsing JAX's documentation, but I couldn't find the proper way to do this. Any help?

move37

1 Answer


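The index you pass to .at is a sequence of row indices followed by a sequence of column indices, not a list of (row, column) pairs. The error message hints at this: three index entries were given for a 2-dimensional array (dim 0 is rows, dim 1 is columns, and there is no dim 2 for the third entry to address). Grouping all the row indices together and all the column indices together should give the desired behavior: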

x = x.at[(1, 0, 1), (0, 1, 1)].set([1, 3, 6])
print(x)
[[0. 3. 0.]
 [1. 6. 0.]
 [0. 0. 0.]]
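The same idea extends to arrays with more dimensions, and to the common case where the updates arrive as a list of coordinate tuples like the ones in the question. Below is a minimal sketch, assuming the coordinates are stored as an integer array of shape (N, ndim). The helper name set_at_coords is hypothetical and not part of JAX. Transposing the coordinates gives one index array per axis, which is exactly the form .at expects. The function is wrapped in jax.jit because JAX's "Sharp Bits" documentation notes that inside jit-compiled code, an x.at[idx].set(y) whose input x is not reused afterwards can be performed in place rather than by allocating a new array. Whether the caller's own buffer is reused here depends on buffer donation, which this sketch does not configure.

```python
import jax
import jax.numpy as jnp


@jax.jit
def set_at_coords(x, coords, values):
    """Hypothetical helper: write `values` into `x` at a batch of coordinates.

    coords: integer array of shape (N, x.ndim), one (row, col, ...) tuple per row.
    values: array of shape (N,), one value per coordinate.
    """
    # .at expects one index array per axis, so transpose the (N, ndim)
    # coordinate list into ndim arrays of length N and pass them as a tuple.
    return x.at[tuple(coords.T)].set(values)


# 2D: the three (row, col) pairs from the question.
pairs = jnp.array([[1, 0], [0, 1], [1, 1]])
x2 = set_at_coords(jnp.zeros((3, 3)), pairs, jnp.array([1.0, 3.0, 6.0]))
print(x2)

# 3D: the same function works unchanged; each coordinate now has three components.
coords3 = jnp.array([[0, 1, 2], [2, 0, 1]])
x3 = set_at_coords(jnp.zeros((3, 3, 3)), coords3, jnp.array([5.0, 7.0]))
print(x3[0, 1, 2], x3[2, 0, 1])
```

Storing coordinates as one (N, ndim) array keeps the call shape-stable, so jax.jit compiles once per shape rather than once per Python tuple of indices.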
Brutus