
In autograd/numpy I could do:

q[q<0] = 0.0

How can I do the same thing in JAX?

I tried import numpy as onp and using that to create arrays, but that doesn't seem to work.

Andriy Drozdyuk

1 Answer

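JAX arrays are immutable, so in-place index assignment statements cannot work. Instead, JAX arrays provide an .at property whose methods (set, add, multiply, min, max, and others) return updated copies of the array. Older JAX versions exposed the same functionality through the jax.ops submodule (e.g. jax.ops.index_update), but those functions have since been removed.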
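Besides set, the .at property supports several other functional updates. The following is a minimal sketch with arbitrarily chosen values; it only uses integer, slice, and integer-array indices, and each call returns a new array while the input stays unchanged.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Each call returns a new array; x itself is never modified.
a = x.at[1].set(10.0)                        # [0.0, 10.0, 0.0, 0.0, 0.0]
b = x.at[1:3].add(1.0)                       # [0.0, 1.0, 1.0, 0.0, 0.0]
c = jnp.arange(5.0).at[::2].multiply(10.0)   # [0.0, 1.0, 20.0, 3.0, 40.0]
d = jnp.arange(5.0).at[:].min(2.0)           # [0.0, 1.0, 2.0, 2.0, 2.0]

# Repeated indices: unlike NumPy's x[idx] += 1, every update is applied,
# so index 0 is incremented twice here.
e = jnp.zeros(3).at[jnp.array([0, 0, 1])].add(1.0)  # [2.0, 1.0, 0.0]

for arr in (x, a, b, c, d, e):
    print(arr.tolist())
```

The accumulation behaviour for repeated indices is one practical difference from NumPy's in-place operators, which apply only one of the duplicated updates.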

Here is an example of a numpy index assignment and the equivalent JAX index update:

import numpy as np
q = np.arange(-5, 5)
q[q < 0] = 0
print(q)
# [0 0 0 0 0 0 1 2 3 4]

import jax.numpy as jnp
q = jnp.arange(-5, 5)
q = q.at[q < 0].set(0)  # NB: this does not modify the original array,
                        # but rather returns a modified copy.
print(q)
# [0 0 0 0 0 0 1 2 3 4]
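For this particular masked assignment, jnp.where expresses the same operation as an elementwise select, with no indexed update at all. Because the output has the same shape as the input regardless of which entries are negative, it can be wrapped in jax.jit. The sketch below uses floating-point values to match the q[q<0] = 0.0 in the question; the function name zero_negatives is just an illustrative choice.

```python
import jax
import jax.numpy as jnp

def zero_negatives(q):
    # Elementwise select: take 0.0 where q < 0, otherwise keep q.
    # The result always has q's shape, so the function traces fine under jit.
    return jnp.where(q < 0, 0.0, q)

q = jnp.array([-1.5, 2.0, -0.5, 3.0])
print(zero_negatives(q).tolist())           # [0.0, 2.0, 0.0, 3.0]
print(jax.jit(zero_negatives)(q).tolist())  # [0.0, 2.0, 0.0, 3.0]

# For the specific case of clamping at zero, jnp.maximum gives the same values.
print(jnp.maximum(q, 0.0).tolist())         # [0.0, 2.0, 0.0, 3.0]
```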

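Note that when run op-by-op (outside of jit), the JAX version creates a new copy of the array for each update rather than modifying it in place. However, when the same code is used within a jax.jit-compiled function, XLA can often fuse such operations and perform the update in place, avoiding the copy.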
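As a sketch of the jit case, the hypothetical function update_ends below chains two .at updates inside one compiled computation. Semantically each update still returns a new array, and the caller's input is left untouched; whether XLA actually reuses buffers internally is a compiler decision that this example does not observe directly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def update_ends(x):
    # Each .at[...] call is functional, but under jit XLA sees the whole
    # chain at once and can often carry out these updates without copies.
    x = x.at[0].set(-1.0)
    x = x.at[-1].add(100.0)
    return x

x = jnp.arange(5.0)
y = update_ends(x)
print(x.tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0]  (original unchanged)
print(y.tolist())  # [-1.0, 1.0, 2.0, 3.0, 104.0]
```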

jakevdp