I have a simple loss function that looks like this:

```python
import jax.numpy as jnp

def loss(r, x, y):
    resid = f(r, x) - y  # f is my model function
    return jnp.mean(jnp.square(resid))
```
I would like to optimize over the parameter `r`, using some static parameters `x` and `y` to compute the residual. All parameters in question are DeviceArrays.
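For reference, here is a self-contained toy version of the setup (`f` here is just a stand-in linear model; my real `f` is more involved):

```python
import jax.numpy as jnp

def f(r, x):
    # stand-in for my real model function
    return r * x

def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

r = jnp.array(2.0)
x = jnp.linspace(0.0, 1.0, 10)
y = 3.0 * x
print(loss(r, x, y))  # runs fine without jit
```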
In order to JIT this, I tried the following:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1, 2))
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))
```
but I get this error:

```
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.
```
I understand from #6233 that this is by design, but I was wondering what the workaround is here, since this seems like a very common use case: you have some fixed (input, output) training data pairs and a free variable to optimize over.
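One workaround I considered (not sure if it's the idiomatic one) is closing over the fixed data instead of marking it static:

```python
import jax
import jax.numpy as jnp

def make_loss(x, y):
    # x and y are captured by the closure rather than passed as
    # static arguments, so nothing needs to be hashable
    @jax.jit
    def loss(r):
        resid = f(r, x) - y  # f and the data as in the setup above
        return jnp.mean(jnp.square(resid))
    return loss

loss_fn = make_loss(x, y)
```

As far as I can tell, the closed-over arrays just get baked into the trace as constants, so this sidesteps the hashability error.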
Thanks for any tips!
EDIT: this is the error I get when I just try to use `jax.jit` (no `static_argnums`):

```
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.
```
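If it helps, I can reproduce the same `ConcretizationTypeError` with a toy `f` that branches on `r` at the Python level, so I assume something similar is happening inside my real `f`:

```python
import jax
import jax.numpy as jnp

def f(r, x):
    if r > 0:  # a Python `if` calls bool() on a traced value under jit
        return r * x
    return -r * x

@jax.jit
def loss(r, x, y):
    resid = f(r, x) - y
    return jnp.mean(jnp.square(resid))

loss(jnp.array(1.0), jnp.ones(3), jnp.zeros(3))  # raises the error above
```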