Since arrays are immutable in JAX, updating N indices one at a time creates N new arrays, one per call to
x = x.at[idx].set(y)
With hundreds of updates per training cycle, this could add up to hundreds or even millions of intermediate arrays. That seems wasteful. Is there a way to update multiple indices in one go? Does anyone know whether this has real overhead, and if so, whether it is significant? Or am I overlooking something?
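For comparison, here is a minimal sketch of the two approaches. It relies on the fact that `.at[...]` also accepts an integer index array, so several positions can be written in a single scatter instead of a Python loop. The function names, the array sizes and the values are made up for illustration, and the indices are assumed to be unique (with duplicate indices, which value `.set` keeps is not guaranteed).

```python
import jax
import jax.numpy as jnp


def update_one_by_one(x, idxs, ys):
    # Hypothetical helper: one .at[].set() call per index, so each
    # iteration produces a new array value.
    for i, y in zip(idxs, ys):
        x = x.at[i].set(y)
    return x


@jax.jit
def update_batched(x, idxs, ys):
    # Hypothetical helper: a single scatter. .at[] accepts an integer
    # index array, so all positions are written in one operation.
    # Assumes idxs has no duplicates.
    return x.at[idxs].set(ys)


x = jnp.zeros(8)
idxs = jnp.array([1, 3, 5])
ys = jnp.array([10.0, 30.0, 50.0])

a = update_one_by_one(x, idxs, ys)
b = update_batched(x, idxs, ys)
print(jnp.array_equal(a, b))
```

Note that the JAX documentation states that inside `jax.jit`-compiled code, if the input `x` of `x.at[idx].set(y)` is not reused afterwards, the compiler can perform the update in place, so the "one new array per update" cost mainly applies to un-jitted (eager) code.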