I am trying to implement a paper called conditionally-structured Gaussian variational inference (CS-GVA).
First, when implementing variational inference algorithms using reparameterization, you express the model parameters theta as a deterministic transform of the variational parameters; the simplest possible example is
def sample_Gaussian(mu, sigma, eps):
    return mu + sigma * eps
where `eps` is generated from `random.normal`.
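A minimal usage sketch of that transform (shapes and values here are just for illustration):
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
eps = random.normal(key, (3,))                            # standard normal noise
sample = sample_Gaussian(jnp.zeros(3), jnp.ones(3), eps)  # one draw from N(0, I)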
The required reparameterization for CS-GVA is slightly more involved:
import jax.numpy as jnp
from jax import random, jit, grad, vjp, vmap
from jax.scipy.linalg import solve_triangular
from functools import partial
from typing import Callable
def reparam(
    v_params: dict,
    eps: jnp.ndarray,
    precision_mask: tuple,
    diag_indices: tuple,
    p: int
):
    """
    Uses stochastic noise and the current variational parameters to
    generate a set of model parameters.

    When using this algorithm for a choice of model, this function has
    multiple static arguments, and needs to be `partial`ly evaluated
    with static arguments for `precision_mask`, `diag_indices` and `p`
    before being used.

    :param v_params: dictionary where each key-value pair corresponds
        to a jnp.ndarray of numbers
    :param eps: length-p noise vector drawn from a standard normal
    :param precision_mask: indices of the non-zero elements of the
        precision matrix
    :param diag_indices: indices of the diagonal elements of the
        precision matrix
    :param p: the number of "global" model parameters in the specified
        probability model
    """
    # scatter the vectorised factor into an upper-triangular matrix
    T = jnp.zeros((p, p))
    T = T.at[precision_mask].set(v_params["T"])
    # exponentiate the diagonal so it is strictly positive
    Tstar = T.at[diag_indices].set(jnp.exp(T[diag_indices]))
    # Tstar is upper triangular, matching solve_triangular's default lower=False
    Teps = solve_triangular(Tstar, eps)
    theta = v_params["mu"] + Teps
    return theta, Teps
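To make the masking concrete, here is a tiny standalone sketch (p = 3, arbitrary values) of how `precision_mask` and `diag_indices` populate `Tstar`:
import jax.numpy as jnp

mask = jnp.mask_indices(3, jnp.triu)      # indices of the upper triangle
diag = jnp.diag_indices(3)
T = jnp.zeros((3, 3)).at[mask].set(jnp.arange(1.0, 7.0))
Tstar = T.at[diag].set(jnp.exp(T[diag]))  # diagonal becomes exp(1), exp(4), exp(6)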
def cond_reparam(
    v_params: dict,
    eps: jnp.ndarray,
    theta: jnp.ndarray,
    Teps: jnp.ndarray,
    precision_mask: tuple,
    diag_indices: tuple,
    p: int
):
    """
    Given global model parameter information, the current variational
    parameters and stochastic noise, generates the "local" model
    parameters.

    When using this algorithm for a choice of model, this function has
    multiple static arguments, and needs to be `partial`ly evaluated
    with static arguments for `precision_mask`, `diag_indices` and `p`
    before being used.

    :param v_params: dictionary of variational parameters;
        must have keys called "d", "D", "f", "F"
    :param eps: stochastic noise of length p
    :param theta: subset of (global) model parameters that the output
        of this function depends on
    :param Teps: intermediate vector returned by `reparam` that the
        output of this function depends on
    :param precision_mask: indices of the non-zero entries of the
        precision matrix
    :param diag_indices: indices of the diagonal entries of the
        precision matrix
    :param p: number of local model parameters to generate
    """
    # the vectorised factor is an affine function of the global parameters
    vecT = v_params["f"] + (v_params["F"] @ theta.T).T
    T = jnp.zeros((p, p))
    T = T.at[precision_mask].set(vecT)
    Tstar = T.at[diag_indices].set(jnp.exp(T[diag_indices]))
    # shift the noise by the global contribution, then rescale
    b = eps - (v_params["D"] @ Teps.T).T
    scale_transform = solve_triangular(Tstar, b)
    delta = v_params["d"] + scale_transform + solve_triangular(Tstar, eps)
    return delta
def cs_gva(
    v_params,
    eps_g,
    eps_l
):
    """
    Calls the `reparam` function for the global parameters, then calls
    the `cond_reparam` function using some of the output from `reparam`.

    :param v_params: dictionary of variational parameters
    :param eps_g: stochastic noise for the global parameters
    :param eps_l: stochastic noise for the local parameters
    """
    theta, Teps = reparam(v_params, eps_g)
    delta = cond_reparam(v_params, eps_l, theta, Teps)
    return jnp.hstack([theta, delta])
Running the function looks something like:
p = jnp.array([5, 537])
mu = jnp.zeros(p[0])
T = jnp.zeros(int(p[0]*(p[0]+1)/2))
precision_mask = jnp.mask_indices(p[0], jnp.triu)
diag_indices = jnp.diag_indices(p[0])
# partially evaluating the function
reparam = jit(partial(
    reparam,
    precision_mask=precision_mask,
    diag_indices=diag_indices,
    p=p[0]
))
d = jnp.zeros(shape=(p[1],))
D = jnp.zeros(shape=(p[1], p[0]))
f = jnp.zeros(shape=(p[1],))
F = jnp.zeros(shape=(p[1], p[0]))
# here the local precision factor is restricted to its diagonal
precision_mask2 = jnp.diag_indices(p[1])
diag_indices2 = jnp.diag_indices(p[1])
# partially evaluating the function
cond_reparam = jit(partial(
    cond_reparam,
    precision_mask=precision_mask2,
    diag_indices=diag_indices2,
    p=p[1]
))
v_params = {
    "mu": mu,
    "T": T,
    "d": d,
    "D": D,
    "f": f,
    "F": F
}
key = random.PRNGKey(2022)
eps = random.normal(key, (p.sum(),))
eps_g = eps[:p[0]]
eps_l = eps[p[0]:]
theta = cs_gva(
    v_params,
    eps_g,
    eps_l
)
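As a quick sanity check on the shapes (using the setup above):
# theta stacks the 5 global and 537 local parameters
assert theta.shape == (int(p.sum()),)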
Calling this function works and does what is expected. The next challenge is to use automatic differentiation to evaluate a vector-Jacobian product, where the cotangent passed to the `vjp_function` is some function applied to `theta` (mathematically, it is the vector of derivatives of the log probability with respect to the model parameters; for simplicity here, I do not apply any function to `theta` before using it as the cotangent to the `vjp_function`):
def calc_vjp(
    generative: Callable,
    dlogp: Callable,
    v_params: dict,
    eps_g: jnp.ndarray,
    eps_l: jnp.ndarray
):
    """
    Calculate the vector-Jacobian product of the Jacobian of the
    `generative` function with respect to the parameters `v_params`,
    applied to a cotangent vector (here simply `theta` itself; in
    general it would be the output of `dlogp`).

    :param generative: the function that takes as input the variational
        parameters and noise and outputs the model parameters
    :param dlogp: evaluates the derivative of the log probability at
        the current model parameters (unused in this simplified version)
    :param v_params: value of the variational parameters, stored as a
        dictionary
    :param eps_g: noise generated for the global parameters
    :param eps_l: noise generated for the local parameters
    """
    theta, vjp_function = vjp(
        generative,
        v_params,
        eps_g,
        eps_l,
        has_aux=False
    )
    # keep only the gradient with respect to v_params (the first input)
    return vjp_function(theta)[0], theta
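For the call below to run, `dlogp` needs a definition; since no transform is applied in this simplified version, an identity placeholder (purely illustrative) will do:
def dlogp(theta):
    # placeholder: in the real algorithm this evaluates the gradient
    # of the log probability at theta
    return theta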
grads, theta = calc_vjp(
    cs_gva,
    dlogp,
    v_params,
    eps_g,
    eps_l
)
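The returned `grads` is a dictionary with the same structure as `v_params`; a quick way to inspect it (assuming the call above):
from jax import tree_util

# each leaf of grads matches the shape of the corresponding leaf of v_params
print(tree_util.tree_map(lambda g: g.shape, grads))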
My struggle at the moment is that I want to do something called "importance weighting": I want to draw multiple `eps` vectors to generate multiple sets of model parameters at each iteration, weight them respectively, and use some weighted combination to estimate the gradient updates (hopefully this approach leads to more stable gradient estimation):
eps_v = random.normal(key, (100, p.sum()))
eps_v_g = eps_v[:, :p[0]]
eps_v_l = eps_v[:, p[0]:]
cs_gva_vmapped = jit(vmap(
    cs_gva,
    in_axes=(None, 0, 0)
))
theta = cs_gva_vmapped(
    v_params,
    eps_v_g,
    eps_v_l
)
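For context, a sketch of the kind of weighted combination I mean, with hypothetical `log_weights` (log importance ratios) for the 100 draws; this is illustrative rather than part of the algorithm above:
from jax import nn

log_weights = jnp.zeros(100)  # hypothetical per-draw log importance weights
w = nn.softmax(log_weights)   # self-normalised weights summing to one
# weighted average over the 100 draws (applied to theta for illustration)
weighted_theta = (w[:, None] * theta).sum(axis=0)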
Here, it is obvious how to `vmap` the `cs_gva` function, and it works as expected. But any attempt to do the same to the `calc_vjp` function doesn't work, and I don't get why:
vmap(
    calc_vjp,
    in_axes=(None, None, None, 0, 0)
)(cs_gva, dlogp, v_params, eps_v_g, eps_v_l)
Any help would be greatly appreciated.