I am trying to get the Jacobian for a simple parameterization function within JAX. The code is as follows:
# imports
import jax
import jax.numpy as jnp
from jax import jacobian, random, vmap
# simple parameterization function
def reparameterize(v_params):
    theta = v_params[0] + jnp.exp(v_params[1]) * eps
    return theta
Suppose I initialize eps to be a vector of shape (3,) and v_params to be of shape (3, 2):
key = random.PRNGKey(2022)
eps = random.normal(key, shape=(3,))
key, _ = random.split(key)
v_params = random.normal(key, shape=(3, 2))
I want the Jacobian to be an array of shape (3, 2), but jacobian(vmap(reparameterize))(v_params) returns an array of shape (3, 3, 3, 2). If I re-initialize with only a single eps:
key, _ = random.split(key)
eps = random.normal(key, shape=(1, ))
key, _ = random.split(key)
v_params = random.normal(key, shape=(2, ))
and call jacobian(reparameterize)(v_params), I get what I want, i.e., an array of shape (2, ). Effectively, looping over all eps and stacking the result of each Jacobian gives me the desired Jacobian (and shape). What am I missing here? Thanks for your help!