4

Since arrays are immutable in JAX, updating N indices one at a time creates N new arrays with

x = x.at[idx].set(y)

With hundreds of updates per training cycle, this will ultimately create hundreds of arrays, if not millions. This seems a little wasteful. Is there a way to update multiple indices in one go? Does anyone know whether there is overhead, and whether it is significant? Am I overlooking something?

move37

2 Answers

4

You can perform multiple updates in a single operation using the syntax you mention. For example:

import jax.numpy as jnp

x = jnp.zeros(10)
idx = jnp.array([3, 5, 7, 9])
y = jnp.array([1, 2, 3, 4])

x = x.at[idx].set(y)
print(x)
# [0. 0. 0. 1. 0. 2. 0. 3. 0. 4.]

You're correct that outside JIT, each update operation will create an array copy. But within JIT-compiled functions, the compiler is able to perform such updates in-place when it is possible (for example, when the original array is not referenced again). You can read more at JAX Sharp Bits: Array Updates.
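As a rough illustration of the point above, here is a minimal sketch that chains two updates inside a JIT-compiled function. The function name `apply_two_updates` and the sample values are made up for this example. Whether the compiler actually reuses the buffer depends on the backend and on what XLA decides, so treat the comments as a description of what is permitted rather than guaranteed.

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_two_updates(x, idx_a, vals_a, idx_b, vals_b):
    # The intermediate result of the first update is never referenced
    # again, so the compiler is free to apply the second update in-place.
    x = x.at[idx_a].set(vals_a)
    x = x.at[idx_b].add(vals_b)
    return x


x = jnp.zeros(6)
out = apply_two_updates(
    x,
    jnp.array([0, 2]), jnp.array([1.0, 2.0]),    # set x[0]=1, x[2]=2
    jnp.array([2, 5]), jnp.array([10.0, 20.0]),  # add 10 to x[2], 20 to x[5]
)

print(out)  # values: 1, 0, 12, 0, 0, 20
print(x)    # the caller's array is unchanged: all zeros
```

From the caller's point of view the semantics stay immutable: `x` still holds zeros after the call, and only the temporaries inside the compiled function are candidates for in-place updates.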

jakevdp
  • Is JAX JIT sophisticated enough to perform in-place updates on something like `x = x.at[1:].add(+y)` then `x = x.at[:-1].add(-y)` where `y` has size one less than `x` and only lives in the same scope (no side effect)? – DavidJ Dec 15 '22 at 11:03
  • Yes, I believe this will be performed in-place, although I don't think the updates will be fused: i.e. it will compile to two in-place scatter-update operations. You can see the compiled HLO with something like `jax.jit(lambda x, y: x.at[1:].add(y).at[:-1].add(-y)).lower(x, y).compile().as_text()` – jakevdp Dec 15 '22 at 14:00
  • Any thoughts on how to do the batch updates for a 2D array? – move37 Jan 03 '23 at 06:18
  • You can use 2D indices in various ways depending on what exactly you want to do. I'd suggest posting a new question describing your particular situation. – jakevdp Jan 03 '23 at 13:09
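For the 2D case raised in the comments, one common pattern (a sketch only; the right indexing depends on what you are trying to update) is to pass one index array per dimension, just as with NumPy advanced indexing. The names `grid`, `rows`, `cols`, and `vals` below are illustrative.

```python
import jax.numpy as jnp

grid = jnp.zeros((3, 4))

# Update individual elements: (rows[i], cols[i]) receives vals[i].
rows = jnp.array([0, 1, 2])
cols = jnp.array([1, 2, 3])
vals = jnp.array([10.0, 20.0, 30.0])
updated = grid.at[rows, cols].set(vals)

# Update whole rows at once: rows 0 and 2 are replaced by ones.
row_updated = grid.at[jnp.array([0, 2])].set(jnp.ones((2, 4)))

print(updated)
print(row_updated)
```

Both calls are single batched updates, so no Python loop over indices is needed.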
3

This sounds very much like a job for a scatter update. I'm not really familiar with JAX itself, but the major frameworks all have one:

https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scatter.html

What it does in a nutshell:

  1. Set up your output tensor (x).
  2. Collect the required update values in another tensor (y in your case).
  3. Collect the indices at which to apply your updates in a list or tensor.
  4. Feed 1-3 to the scatter update operation (`jax.lax.scatter` in JAX).
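The steps above can be sketched in JAX as follows, next to the higher-level `.at[...].set(...)` form for comparison. This is a minimal sketch for the 1D case only; the dimension numbers shown are my reading of what a plain element-wise scatter needs, and the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(10)                      # 1. output tensor
y = jnp.array([1.0, 2.0, 3.0, 4.0])    # 2. update values (same dtype as x)
idx = jnp.array([3, 5, 7, 9])          # 3. indices to update

# Higher-level API.
high_level = x.at[idx].set(y)

# 4. Low-level scatter. Each row of the (4, 1) index array addresses one
# element along operand dimension 0, and each update is a scalar.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
low_level = lax.scatter(x, idx[:, None], y, dnums)

print(high_level)
print(low_level)
```

The two results should match, and the amount of bookkeeping in the `lax.scatter` call shows why the `.at[...]` form is usually the more convenient choice.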
CaptainTrunky
  • Quick note for clarification: `x.at[idx].set(y)` is equivalent to `lax.scatter`; it's just a more convenient higher-level API for the same functionality. In general, I'd recommend avoiding direct use of `lax.scatter` and other low-level `lax` functions. – jakevdp Dec 06 '22 at 15:52
  • I assume the recommendation comes because the `lax` API is subject to change whereas the higher-level API will remain consistent. – DavidJ Dec 15 '22 at 22:24