
The following code does not work:

import jax.numpy as jnp
from jax import jit

def get_unique(arr):
    return jnp.unique(arr)

get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))  # raises ConcretizationTypeError under jit

The error message complains about the use of jnp.unique:

FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()

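The "Sharp Bits" section of the JAX documentation explains that jit doesn't work if the shapes of intermediate arrays depend on argument values. This is exactly the case here: the number of unique elements, and therefore the output shape of jnp.unique, depends on the values in arr.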
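To make the shape dependence concrete, here is a small illustration run eagerly (outside jit). Two inputs with the same shape and dtype produce outputs of different lengths, so jit, which traces a function using only shapes and dtypes, has no way to know the output shape in advance:

```python
import jax.numpy as jnp

# Both inputs have the same abstract type: float32[5].
a = jnp.array([1.0, 1.0, 1.0, 2.0, 2.0])
b = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Eagerly, jnp.unique sees concrete values, so it can pick the output length.
print(jnp.unique(a).shape)  # (2,)
print(jnp.unique(b).shape)  # (5,)
```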

According to the docs, one potential workaround is to mark arguments as static. But that doesn't apply to my case: the argument values change on almost every call, so marking them static would force a recompilation on nearly every call. Instead, I have split my code into a preprocessing step, which performs calculations such as this jnp.unique, and a computation step that can be jitted.
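A minimal sketch of that split, assuming the jitted step only needs the unique values as input. The names `preprocess` and `compute` are hypothetical placeholders, and the sum of squares stands in for whatever the real computation does. Note that jit caches compiled code per input shape, so each new number of unique values still triggers one retrace and compile of `compute`:

```python
import jax
import jax.numpy as jnp

def preprocess(arr):
    # Hypothetical helper: runs eagerly (not under jit), so the
    # value-dependent output shape of jnp.unique is allowed here.
    return jnp.unique(arr)

@jax.jit
def compute(unique_vals):
    # Hypothetical placeholder computation: inside jit every shape
    # is fixed at trace time.
    return jnp.sum(unique_vals ** 2)

x = jnp.array([3.0, 1.0, 3.0, 2.0])
u = preprocess(x)     # eager, shape depends on the values in x
result = compute(u)   # jitted, compiled once per distinct shape of u
```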

Still, I'd like to ask: is there a workaround that I'm not aware of?

lhk

1 Answer


No. For the reasons you mention, there's currently no way to use jnp.unique on a non-static value inside jit.

In similar cases, JAX sometimes adds extra parameters that can be used to specify a static size for the output (for example, the size parameter in jax.numpy.nonzero), but nothing like that is currently implemented for jnp.unique. If that is something you'd like, it would be worth filing a feature request.
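As a sketch of the static-size pattern, the first function below uses the real `size` and `fill_value` arguments of `jnp.nonzero` under jit. The second, `unique_fixed_size`, is a hypothetical helper (not part of JAX) that builds a padded, fixed-length unique from `jnp.sort` plus `jnp.nonzero(..., size=...)`. It assumes the caller knows an upper bound on the number of distinct values; anything beyond `size` is silently dropped. Later JAX releases did add `size` and `fill_value` arguments to `jnp.unique` itself, so check your version before reaching for a workaround like this:

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_fixed(x):
    # size=3 fixes the output length at trace time; unused slots get fill_value.
    return jnp.nonzero(x, size=3, fill_value=-1)

def unique_fixed_size(arr, size, fill_value):
    # Hypothetical helper: returns exactly `size` sorted unique values,
    # padded with `fill_value` (truncated if there are more than `size`).
    s = jnp.sort(arr.ravel())
    # True at the first element of each run of equal values.
    is_first = jnp.concatenate([jnp.array([True]), s[1:] != s[:-1]])
    # Positions of those first elements, padded with 0 to a static length.
    (idx,) = jnp.nonzero(is_first, size=size, fill_value=0)
    n = jnp.sum(is_first)  # number of distinct values (a traced scalar)
    vals = s[idx]
    # Overwrite the padded slots with fill_value.
    return jnp.where(jnp.arange(size) < n, vals, fill_value)

# `size` determines the output shape, so it must be static under jit.
unique_fixed_size_jit = jax.jit(unique_fixed_size, static_argnames="size")

print(nonzero_fixed(jnp.array([0, 5, 0, 7])))
print(unique_fixed_size_jit(jnp.ones((10,)), size=3, fill_value=-1.0))
```

The output shape now depends only on `size`, so jit compiles once per chosen `size` rather than once per distinct input.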

jakevdp