Important note: I need everything here to be jit-compatible; otherwise my problem would be trivial :)
I have a JAX NumPy array such as:
import jax.numpy as jnp
a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
First, I filter it against a threshold. Let's assume I only keep values that are < 5 and replace the rest with NaN:
a = jnp.where(a < 5, a, jnp.nan)
# a is now [ 1. nan 3. 4. nan nan nan 2. nan]
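The masking step on its own is already jit-compatible, because its output always has the same shape as its input. Here is a minimal sketch, assuming the threshold is passed as an ordinary function argument (the name mask_below is hypothetical):

```python
import jax
import jax.numpy as jnp

@jax.jit
def mask_below(a, threshold):
    # Hypothetical helper: output shape always equals a.shape, so it traces fine.
    # Entries that fail the test are replaced with NaN instead of being removed.
    return jnp.where(a < threshold, a, jnp.nan)

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
print(mask_below(a, 5))
print(mask_below(a, 3))  # new threshold value, same input shape and dtype
```

Because threshold is traced rather than static, changing its value should reuse the compiled function as long as its type stays the same.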
I want to keep only the non-NaN values, [ 1. 3. 4. 2.], and then use this array in further operations.
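The step that breaks under jit is the NaN removal itself. The sketch below (drop_nans is a hypothetical name) shows that boolean indexing works eagerly but raises jax.errors.NonConcreteBooleanIndexError when traced, because the length of the result is not known at trace time:

```python
import jax
import jax.numpy as jnp

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
masked = jnp.where(a < 5, a, jnp.nan)

# Eagerly, the mask is concrete, so the output length can be computed.
kept = masked[~jnp.isnan(masked)]
print(kept)

@jax.jit
def drop_nans(x):
    # Hypothetical helper: under jit the mask is a tracer, so the
    # output length is unknown at trace time and indexing fails.
    return x[~jnp.isnan(x)]

try:
    drop_nans(masked)
    jit_failed = False
except jax.errors.NonConcreteBooleanIndexError:
    jit_failed = True
print("jit failed:", jit_failed)
```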
But more importantly, this code will be executed multiple times during my program, with a threshold that changes between calls (i.e. it won't always be 5).
Hence, the shape of the final array changes too. That is my problem with jit compilation: I don't know how to make this jit-compatible, since the output shape depends on how many elements satisfy the threshold condition.
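One common workaround, sketched here under the assumption that a fixed maximum size (the input length) is acceptable downstream, is to keep the output shape static by padding and to carry the number of valid entries alongside it. jnp.nonzero accepts a static size and fill_value for this purpose; filter_below is a hypothetical name, and this is not necessarily the right design for every use case:

```python
import jax
import jax.numpy as jnp

@jax.jit
def filter_below(a, threshold):
    # Hypothetical helper: fixed-size stand-in for a[a < threshold].
    n = a.shape[0]                  # static under jit, so usable as `size`
    mask = a < threshold
    count = jnp.sum(mask)           # number of real entries (traced value)
    # Indices of True entries, padded with 0 up to length n.
    (idx,) = jnp.nonzero(mask, size=n, fill_value=0)
    # Slots at position >= count are padding; mark them with NaN.
    valid = jnp.arange(n) < count
    values = jnp.where(valid, a[idx], jnp.nan)
    return values, count

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
values, count = filter_below(a, 5)  # same compiled function for any int threshold
print(values, count)
print(jnp.nansum(values))           # NaN-aware reduction ignores the padding
```

Downstream code then has to respect the padding, for example by using NaN-aware reductions such as jnp.nansum or by masking with count.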