
I'm new to JAX, and writing code that JIT-compiles is proving quite hard for me. I am trying to achieve the following:

Given an (n,n) array mat in JAX, I would like to add a (1,n) or an (n,1) array to an arbitrary row or column, respectively, of the original array mat.

If I wanted to add a row array, r, to the third row, the numpy equivalent would be,

# if mat is a numpy array
mat[2,:] = mat[2,:] + r

The only way I know how to update an element of an array in JAX is using array.at[i].set(). I am not sure how one can use this to update a row or a column without explicitly using a for-loop.

  • You'll find it easier if you adopt the functional style of JAX: don't mutate arrays; define new ones in terms of the old ones. It takes practice to write code effectively that way, but it's the recommended approach and works nicely with jit – joel Jan 07 '23 at 22:07
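To illustrate the comment's point, here is a minimal sketch of the functional style. Item assignment on a JAX array raises a `TypeError`. `.at[...].set(...)` instead returns a new array and leaves the original untouched, and the same pattern works inside a jitted function. The helper name `fill_row_zero` is hypothetical, chosen only for this example.

```python
import jax
import jax.numpy as jnp

mat = jnp.zeros((3, 3))

# JAX arrays are immutable: NumPy-style item assignment raises TypeError.
try:
    mat[0, :] = 1.0
    mutated = True
except TypeError:
    mutated = False

# Functional style: build a new array from the old one instead.
# .at[...].set(...) returns a new array and leaves `mat` unchanged.
new_mat = mat.at[0, :].set(1.0)

# Hypothetical helper: the same pattern traces and compiles under jit.
@jax.jit
def fill_row_zero(m, value):
    return m.at[0, :].set(value)

jitted = fill_row_zero(mat, 2.0)

print(mutated)               # False
print(float(mat.sum()))      # 0.0
print(float(new_mat.sum()))  # 3.0
print(float(jitted.sum()))   # 6.0
```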

1 Answer


JAX arrays are immutable, so you cannot modify array entries in place. But you can accomplish similar results with the jax.numpy.ndarray.at syntax. For example, the equivalent of

mat[2,:] = mat[2,:] + r

would be

mat = mat.at[2,:].set(mat[2,:] + r)

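But in this case the add method is simpler and can be more efficient:

mat = mat.at[2, :].add(r)

Here r is a 1D array of shape (n,). If you keep it as a (1, n) array, slicing with 2:3 selects a target of the same (1, n) shape, as in the example below.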
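Since the question asks about an arbitrary row under JIT, here is a minimal sketch with the row index passed as a traced argument. It assumes that flattening r with `reshape(-1)` is acceptable, so that both (n,) and (1, n) inputs match the 1D target `mat[i, :]`. The function name `add_to_row` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_to_row(mat, r, i):
    # Flatten r so both (n,) and (1, n) inputs match the 1D target mat[i, :].
    # i is a traced integer here; integer indexing inside .at[] supports that.
    return mat.at[i, :].add(r.reshape(-1))

mat = jnp.zeros((4, 4))
r = jnp.arange(4.0).reshape(1, 4)  # a (1, n) row, as in the question

out = add_to_row(mat, r, 2)
print(out[2])  # [0. 1. 2. 3.]

# A different row index reuses the compiled function,
# because i is a regular (non-static) argument.
out2 = add_to_row(out, r, 0)
print(out2[0])  # [0. 1. 2. 3.]
```

No Python for-loop is involved: the update compiles to a single indexed add, and the same compiled function handles any row index.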

Here is an example of adding a row array and a column array to a 2D array:

import jax.numpy as jnp

mat = jnp.zeros((5, 5))

# Create 2D row & col arrays, as in question
row = jnp.ones(5).reshape(1, 5)
col = jnp.ones(5).reshape(5, 1)

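# 1:2 and 2:3 select row 1 and column 2 as (1, 5) and (5, 1) slices,
# matching the shapes of row and col
mat = mat.at[1:2, :].add(row)
mat = mat.at[:, 2:3].add(col)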

print(mat)
# [[0. 0. 1. 0. 0.]
#  [1. 1. 2. 1. 1.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]
#  [0. 0. 1. 0. 0.]]
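The slices 1:2 and 2:3 above use fixed Python integers. If the column index is itself a traced argument of a jitted function, a slice like j:j+1 is, as far as I know, rejected because NumPy-style slice bounds must be static. Here is a minimal sketch that avoids this by flattening the (n, 1) column and using integer indexing. The function name `add_to_col` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_to_col(mat, c, j):
    # Flatten c so an (n, 1) column matches the 1D target mat[:, j].
    # j is traced under jit, so integer indexing is used instead of a
    # j:j+1 slice, whose bounds would need to be static.
    return mat.at[:, j].add(c.reshape(-1))

mat = jnp.zeros((5, 5))
col = jnp.ones(5).reshape(5, 1)

out = add_to_col(mat, col, 2)
print(out[:, 2])         # [1. 1. 1. 1. 1.]
print(float(out.sum()))  # 5.0
```

If you would rather keep the 2D block shapes, jax.lax.dynamic_slice and jax.lax.dynamic_update_slice accept traced start indices, at the cost of slightly more verbose code.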

See JAX Sharp Bits: In-Place Updates for more discussion of this.

jakevdp