
I have a vector-Jacobian product that I want to compute.

The function func takes four arguments, the final two of which are static:

def func(variational_params, e, A, B):
    ...
    return model_params, dlogp, ...

The function jits perfectly fine via

func_jitted = jit(func, static_argnums=(2, 3))

The primals are the variational_params, and the cotangents are dlogp (the second output of the function).

Calculating the vector-Jacobian product naively (by forming the full Jacobian) works fine:

jacobian_func = jacobian(func_jitted, argnums=0, has_aux=True)
jacobian_jitted = jit(jacobian_func, static_argnums=(2, 3))
jac, func_output = jacobian_jitted(variational_params, e, A, B)
naive_vjp = func_output.T @ jac 

When I try to form the vjp more efficiently via

f_eval, vjp_function, aux_output = vjp(func_jitted, variational_params, e, A, B, has_aux=True)

I get the following error:

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.ad.JVPTracer'> for function func is non-hashable.

I am a little confused, since the function func jitted perfectly fine. There is no option for adding static_argnums to the vjp function, so I am not sure what this error means.

hasco641

1 Answer


For higher-level transformation APIs like jit, JAX generally provides a mechanism like static_argnums or argnums for specifying which arguments are static and which are dynamic.

For lower-level transformation routines like jvp and vjp, these mechanisms are not provided, but you can still accomplish the same thing by passing partially-evaluated functions. For example:

from functools import partial

f_eval, vjp_function, aux_output = vjp(partial(func_jitted, A=A, B=B), variational_params, e, has_aux=True)

This is effectively how transformation parameters like argnums and static_argnums are implemented under the hood.
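As a minimal, self-contained sketch of the pattern (the function f, its shapes, and the scale argument are illustrative stand-ins, not the original func):

```python
from functools import partial

import jax
import jax.numpy as jnp

# Toy stand-in for func: x is dynamic, scale plays the role of a static argument.
def f(x, scale):
    y = scale * jnp.sin(x)
    aux = jnp.sum(y)  # auxiliary output, not differentiated
    return y, aux

x = jnp.arange(3.0)

# Close over the "static" argument via partial instead of static_argnums.
primals_out, vjp_fn, aux = jax.vjp(partial(f, scale=2.0), x, has_aux=True)

# Pulling back a cotangent gives the same result as contracting
# the full Jacobian with that cotangent.
cotangent = jnp.ones_like(primals_out)
(vjp_x,) = vjp_fn(cotangent)

jac = jax.jacobian(lambda x: f(x, 2.0)[0])(x)
assert jnp.allclose(vjp_x, cotangent @ jac)
```

Note that because vjp traces the function it receives, closing over the non-differentiable arguments with partial (or a lambda) is all that is needed; they are simply never seen by the autodiff machinery.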

jakevdp
  • Thanks for the answer, Jake. This works perfectly but has now led to another hiccup. The cotangents, given by `aux_output[0]` and the input to the function `e` may be batches and share the same value in the leading dimension, and I want to vmap over these batches. Using the naive implementation provided in the question, it is easy to wrap as the following `vmap`ped function: ```def naive_vjp(func_output, jac): return func_output.T @ jac; naive_vjp = vmap(naive_vjp, in_axes=(0, 0))```. Unfortunately, `vmap`ping the `vjp_func` given by your solution does not work because I would have to `vmap` – hasco641 Oct 15 '22 at 07:55
  • `func` with respect to the argument `e`. There are two questions that I have: I am currently working on the assumption that in order to `vmap` `vjp_func` over `e` that you would have to be able to `vmap` `func` over `e`; is this true? Secondly in order to (re)-write the function in a way that it can be `vmap`ped, I am running into issues with having to `vmap` nested functions over different arguments: see this [question here](https://stackoverflow.com/questions/74077964/vmapping-a-top-level-function-when-there-are-multiple-nested-functions-that-re) – hasco641 Oct 15 '22 at 08:45
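On the vmap follow-up in the comments: one common pattern is to construct the vjp inside the function being vmapped, so each batch element gets its own pullback. A sketch with a toy function (not the original func; names and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

# Toy stand-in: params are shared, e is batched.
def f(params, e):
    y = jnp.sin(params) * e
    aux = jnp.sum(y)
    return y, aux

def per_example_vjp(params, e, cotangent):
    # Build the vjp for a single batch element, then pull back its cotangent.
    _, vjp_fn, _ = jax.vjp(f, params, e, has_aux=True)
    d_params, d_e = vjp_fn(cotangent)
    return d_params

params = jnp.arange(3.0)
e_batch = jnp.ones((4, 3))
ct_batch = jnp.ones((4, 3))

# vmap over the leading dimension of e and the cotangents; params are shared.
batched_vjp = jax.vmap(per_example_vjp, in_axes=(None, 0, 0))
grads = batched_vjp(params, e_batch, ct_batch)  # shape (4, 3)
```

This sidesteps vmapping the vjp_fn object itself: vmap transforms the whole per-example computation, including the trace that vjp performs.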