I'm new to JAX, and writing code that JIT-compiles is proving quite hard for me. I am trying to achieve the following:
Given an `(n, n)` array `mat` in JAX, I would like to add a `(1, n)` or an `(n, 1)` array to an arbitrary row or column, respectively, of the original array `mat`.
If I wanted to add a row array `r` to the third row, the NumPy equivalent would be:
```python
# if mat is a numpy array
mat[2,:] = mat[2,:] + r
```
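For comparison, a minimal sketch of how that NumPy line might translate to JAX, assuming `mat` is `(n, n)` and `r` is `(1, n)` (the concrete sizes and values are illustrative). The `.at[...]` property accepts NumPy-style slices, not just single elements, and offers `.add()` alongside `.set()`:

```python
import jax.numpy as jnp

n = 4
mat = jnp.zeros((n, n))
r = jnp.ones((1, n))  # a (1, n) row array

# JAX arrays are immutable: .at[...].add(...) returns a new array
# instead of modifying mat in place. The (1, n) row is flattened to
# (n,) so that it matches the shape of the indexed row mat[2, :].
mat = mat.at[2, :].add(r.reshape(-1))
```

Because the result is a new array, it has to be reassigned to `mat` to mimic the in-place NumPy update.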
The only way I know to update an element of an array in JAX is `array.at[i].set()`. I am not sure how to use this to update a whole row or column without an explicit for-loop.
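A minimal sketch of one way this could work under `jax.jit`, assuming the row or column index is passed as a runtime argument. The helper names `add_to_row` and `add_to_col` are hypothetical. Integer indexing with a traced index inside `.at[...]` is supported, so no Python loop over elements is needed; the `(1, n)` and `(n, 1)` inputs are flattened to match the indexed 1-D slice:

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_to_row(mat, r, i):
    # i may be a traced integer; .at[i, :] selects the whole row i,
    # and .add() returns a new array with r added to that row.
    return mat.at[i, :].add(r.reshape(-1))


@jax.jit
def add_to_col(mat, c, j):
    # Same idea for an (n, 1) column array added to column j.
    return mat.at[:, j].add(c.reshape(-1))


mat = jnp.arange(9.0).reshape(3, 3)
row = jnp.array([[10.0, 20.0, 30.0]])         # shape (1, 3)
col = jnp.array([[100.0], [200.0], [300.0]])  # shape (3, 1)

new_mat = add_to_row(mat, row, 2)      # adds row to mat[2, :]
new_mat = add_to_col(new_mat, col, 0)  # adds col to new_mat[:, 0]
```

Since the index is an ordinary argument rather than a static one, changing `i` or `j` between calls should not trigger recompilation; only a change in array shapes or dtypes would.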