
When I pass an object created by the following function into a function built with jax.lax.scan:

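import jax
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam
from numpyro.infer.util import initialize_model

def logdensity_create(model, centeredness=None, varname=None):
    # Optionally reparameterize `varname` with the given degree of centeredness.
    if centeredness is not None:
        model = reparam(model, config={varname: LocScaleReparam(centered=centeredness)})

    # With dynamic_args=True, potential_fn_gen(*model_args) returns the potential function.
    init_params, potential_fn_gen, *_ = initialize_model(jax.random.PRNGKey(0), model, dynamic_args=True)
    logdensity = lambda position: -potential_fn_gen()(position)
    initial_position = init_params.z
    return (logdensity, initial_position)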

I get the following error when I pass logdensity to an iterative function built with jax.lax.scan:

TypeError: Value <function logdensity_create.<locals>.<lambda> at 0x13fca7d80> with type <class 'function'> is not a valid JAX type

How can I resolve this error?

jakevdp
imk
  • The ideas in this [GitHub issue](https://github.com/google/jax/issues/1443) are probably helpful – Nin17 Jul 10 '23 at 15:25

1 Answer


I would probably do this via jax.tree_util.Partial, which wraps callables in a PyTree for compatibility with jit and other transformations:

# Inside logdensity_create, in place of the plain lambda:
logdensity = jax.tree_util.Partial(lambda position: -potential_fn_gen()(position))
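To illustrate the idea end to end without NumPyro, here is a minimal sketch. A hypothetical standard-normal log density (`make_logdensity`) stands in for the `potential_fn_gen` lambda, and a hypothetical jitted `gradient_ascent` function plays the role of the scan-based iterative function from the question. Because the callable is wrapped in `Partial`, it can be passed as an ordinary argument to the jitted function; passing the bare lambda instead should fail with a `TypeError` similar to the one in the question.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def make_logdensity():
    # Hypothetical stand-in for the NumPyro potential: a standard normal
    # log density (up to an additive constant), wrapped in Partial so that
    # it is a PyTree rather than a bare Python function.
    return Partial(lambda position: -0.5 * jnp.sum(position ** 2))


@jax.jit
def gradient_ascent(logdensity, initial_position, step_size):
    # `logdensity` arrives as a Partial PyTree, so jit accepts it as an argument.
    grad_fn = jax.grad(logdensity)

    def step(position, _):
        # One gradient-ascent step; also record the log density at the new point.
        new_position = position + step_size * grad_fn(position)
        return new_position, logdensity(new_position)

    # Iterate 100 times with lax.scan; there are no per-step inputs.
    return jax.lax.scan(step, initial_position, xs=None, length=100)


logdensity = make_logdensity()
final, trace = gradient_ascent(logdensity, jnp.array([1.0, -2.0]), 0.1)
```

Since the gradient of this log density is `-position`, each step with a step size of 0.1 scales the position by 0.9, so `final` should be close to `[1.0, -2.0] * 0.9**100`.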
jakevdp
  • The flexibility between JIT and AOT is a fantastic feature of JAX. Is the Partial function cached if passed to several JIT'ed functions? – DavidJ Jul 27 '23 at 12:22
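On the caching question in the comment above: as far as I understand, `Partial` registers the wrapped callable as static PyTree metadata and only its bound arguments as leaves, so jit's compilation cache is keyed on the identity of the callable. The sketch below checks this with a hypothetical `evaluate` function and a trace counter: reusing the same `Partial` object should hit the cache, while an equivalent but newly created lambda should trigger a retrace. Each `jax.jit`-decorated function keeps its own cache, so passing one `Partial` to several jitted functions should trace each of them once per distinct input shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

trace_count = 0


@jax.jit
def evaluate(logdensity, x):
    global trace_count
    # Python side effects run only while jit traces, so this counts traces.
    trace_count += 1
    return logdensity(x)


f = Partial(lambda x: -0.5 * jnp.sum(x ** 2))
evaluate(f, jnp.ones(3))
evaluate(f, jnp.zeros(3))  # same Partial, same shape/dtype: cached, no retrace
count_after_reuse = trace_count

g = Partial(lambda x: -0.5 * jnp.sum(x ** 2))  # equivalent, but a new lambda
evaluate(g, jnp.ones(3))  # different static callable: jit traces again
count_after_new_lambda = trace_count
```

In practice this suggests creating the `Partial` once (for example, inside `logdensity_create`) and reusing that object, rather than rebuilding the lambda before every call.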