
Important note: I need everything to be JIT-compatible here; otherwise my problem is trivial :)

I have a JAX NumPy array such as:

a = jnp.array([1,5,3,4,5,6,7,2,9])

First I filter it by a threshold; let's assume I only keep values that are < 5:

a = jnp.where(a < 5, a, jnp.nan)  # x and y passed positionally, as jnp.where expects
# a is now [ 1. nan  3.  4. nan nan nan  2. nan]

I want to keep only the non-nan values, [ 1. 3. 4. 2.], and then use this array in other operations.

More importantly, this code will run many times during my program with a threshold that changes (i.e. it won't always be 5).

Hence the shape of the final array changes too, and that is my problem with JIT compilation: I don't know how to make this JIT-compatible when the shape depends on how many elements satisfy the threshold condition.
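To make the distinction concrete, here is a minimal sketch (the function names mask_with_nan and compact are illustrative, not from the question). The jnp.where masking step has a fixed output shape and compiles fine, while boolean-mask indexing works eagerly but fails under jax.jit, because its output shape depends on a traced value.

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])

# Outside jit, boolean-mask indexing works because the mask is concrete.
eager = a[a < 5]  # shape (4,)

@jax.jit
def mask_with_nan(a, threshold):
    # Fine under jit: the output shape always equals the input shape.
    return jnp.where(a < threshold, a, jnp.nan)

@jax.jit
def compact(a, threshold):
    # Fails under jit: the output shape depends on the traced threshold.
    return a[a < threshold]

masked = mask_with_nan(a, 5)

try:
    compact(a, 5)
    jit_failed = False
except Exception as err:  # recent JAX versions raise NonConcreteBooleanIndexError here
    jit_failed = True
    print(type(err).__name__)
```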

Valentin Macé

1 Answer


JAX's JIT is not currently compatible with arrays of dynamic (data-dependent) shape, so there is no way to do what your question asks.

There is some experimental work in progress on handling dynamic shapes within JAX transforms like JIT (see https://github.com/google/jax/pull/9335), but I'm not certain when it will be available to use.

The usual workaround for this is to re-express your computations in terms of statically-shaped arrays with a fill value; for example, you could use something like this:

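idx, = jnp.where(a < 5, size=len(a), fill_value=0)  # indices of kept values, padded with 0
a = jnp.where(jnp.arange(len(a)) < jnp.sum(a < 5), a[idx], np.nan)  # overwrite padded tail with nan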

This will create an array of the same length as a, with non-nan values at the front, and filled with nan values at the end.
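Putting this together with a threshold that varies between calls, a minimal sketch might look like the following. filter_below is a hypothetical helper name; it uses jnp.nonzero with a static size (padding the index array with 0), then overwrites the padded tail with nan, and it also returns the count of kept elements so later code knows how many entries are valid.

```python
import jax
import jax.numpy as jnp

@jax.jit
def filter_below(a, threshold):
    """Hypothetical helper: values < threshold moved to the front, nan-padded.

    Returns (padded, count). padded always has shape a.shape, so a single
    compiled version serves every threshold value.
    """
    mask = a < threshold
    # size=a.size is a static Python int, so the index array has a fixed shape;
    # slots beyond the number of matches are filled with index 0.
    (idx,) = jnp.nonzero(mask, size=a.size, fill_value=0)
    count = jnp.sum(mask)
    # Positions >= count hold padding (a[0]); replace them with nan.
    padded = jnp.where(jnp.arange(a.size) < count, a[idx], jnp.nan)
    return padded, count

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
vals, n = filter_below(a, 5)
vals7, n7 = filter_below(a, 7)  # same output shape as the call above

# nan-aware reductions let downstream code ignore the padding.
mean_below_5 = jnp.nanmean(vals)
```

Because threshold is an ordinary traced argument rather than a static one, calling the function with a different threshold should reuse the same compiled code. Functions such as jnp.nanmean or jnp.nansum, or masking with jnp.arange(len(vals)) < n, let later operations work on the padded array without ever needing a data-dependent shape.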

jakevdp