
I am trying to implement the method from a paper on conditionally-structured Gaussian variational inference (CS-GVA).

First, when implementing variational inference algorithms with the reparameterization trick, you express the model parameters theta as a deterministic transform of the variational parameters and a noise variable. The simplest possible example is:

def sample_Gaussian(mu, sigma, eps):
    # theta = mu + sigma * eps, where eps is standard normal noise
    return mu + sigma * eps

where eps is drawn from a standard normal distribution (e.g. with random.normal). The required reparameterization for CS-GVA is slightly more involved:

import jax.numpy as jnp 
from jax import random, jit, grad, vjp, vmap
from jax.scipy.linalg import solve_triangular
from functools import partial 
from typing import Callable

def reparam(
    v_params: dict,
    eps: jnp.ndarray,
    precision_mask: tuple,
    diag_indices: tuple,
    p: int
):
    """
    uses stochastic noise and the current variational parameters to generate a set of global model parameters
    
    when using this algorithm for a choice of model, this function has multiple static 
    arguments, and needs to be `partial`ly evaluated with static arguments for `precision_mask`, 
    `diag_indices`, `p` before being used

    :param v_params: dictionary object, where each key-value pair corresponds to a jnp.array 
        of numbers
    :param eps: noise of len(p) drawn from standard normal 
    :param precision_mask: indices of the non-zero elements of the precision matrix 
    :param diag_indices: provides indices for the diagonal elements of the precision matrix
    :param p: the number of "global" model parameters in the specified probability model
    """
    T = jnp.zeros((p, p))
    T = T.at[precision_mask].set(v_params["T"])
    Tstar = T.at[diag_indices].set(jnp.exp(T[diag_indices]))
    Teps = solve_triangular(Tstar, eps)
    theta = v_params["mu"] + Teps
    return theta, Teps 

def cond_reparam(
    v_params: dict,
    eps: jnp.ndarray,
    theta: jnp.ndarray,
    Teps: jnp.ndarray,
    precision_mask: tuple,
    diag_indices: tuple,
    p: int
):
    """
    given global model parameter information, current variational 
    parameters and stochastic noise, generates "local" model 
    parameters 

    when using this algorithm for a choice of model, this function has 
    multiple static arguments, and needs to be `partial`ly evaluated with 
    static arguments for `precision_mask`, `diag_indices`, `p` before being used

    :param v_params: dictionary of variational parameters. 
        must have keys called "d", "D", "f", "F"
    :param eps: stochastic noise of len(p)
    :param theta: subset of model parameters that the output of this 
        function is dependent on
    :param Teps: function involving subset of model parameters that 
        the output of this function is dependent on 
    :param precision_mask: indices of non-zero entries of precision matrix  
    :param diag_indices: indices of the diagonal entries of precision 
        matrix 
    :param p: number of local model parameters to generate
    """
    vecT = v_params["f"] + (v_params["F"] @ theta.T).T
    T = jnp.zeros((p, p))
    T = T.at[precision_mask].set(vecT)
    Tstar = T.at[diag_indices].set(jnp.exp(T[diag_indices]))
    b = eps - (v_params["D"] @ Teps.T).T
    scale_transform = solve_triangular(Tstar, b)
    delta = v_params["d"] + scale_transform + solve_triangular(Tstar, eps)
    return delta

def cs_gva(
    v_params,
    eps_g, 
    eps_l 
):
    """
    function that calls the `reparam` function for the global params, 
    then calls the `cond_reparam` function using some output from the `reparam` function 

    :param v_params: dictionary of variational parameters  
    :param eps_g: stochastic noise for the global parameters
    :param eps_l: stochastic noise for the local parameters 
    """
    theta, Teps = reparam(v_params, eps_g)
    delta = cond_reparam(v_params, eps_l, theta, Teps)
    return jnp.hstack([theta, delta])
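As a sanity check on the global step: `reparam` computes theta = mu + Tstar^{-1} eps with an upper-triangular Tstar whose diagonal is exponentiated, so for eps ~ N(0, I) the covariance of theta is Tstar^{-1} Tstar^{-T}, i.e. the implied precision matrix is Tstar^T Tstar. A minimal self-contained check, on a hypothetical 3-parameter example with made-up packed values, compares the Jacobian of theta with respect to eps against that precision:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

p = 3  # hypothetical number of global parameters
packed = jnp.array([0.1, 0.5, -0.3, -0.2, 0.4, 0.3])  # p*(p+1)/2 made-up entries
upper = jnp.triu_indices(p)  # (row, col) index pairs of the upper triangle
diag = jnp.diag_indices(p)

# Same construction as in `reparam`: scatter, then exponentiate the diagonal.
T = jnp.zeros((p, p)).at[upper].set(packed)
Tstar = T.at[diag].set(jnp.exp(T[diag]))
mu = jnp.zeros(p)

def theta_of_eps(eps):
    # theta = mu + Tstar^{-1} eps; solve_triangular treats Tstar as upper-triangular by default
    return mu + solve_triangular(Tstar, eps)

J = jax.jacfwd(theta_of_eps)(jnp.zeros(p))  # Jacobian d theta / d eps, i.e. Tstar^{-1}
cov = J @ J.T                                # Cov(theta) when eps ~ N(0, I)
precision = Tstar.T @ Tstar
print(jnp.allclose(cov @ precision, jnp.eye(p), atol=1e-4))
```

Because the transform is linear in eps, the Jacobian does not depend on the point at which it is evaluated; zeros are used only for convenience.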

Running the function looks something like:

p = jnp.array([5, 537])
mu = jnp.zeros(p[0])
T = jnp.zeros(int(p[0]*(p[0]+1)/2))
precision_mask = jnp.mask_indices(p[0], jnp.triu)
diag_indices = jnp.diag_indices(p[0])

# partially evaluating the function 
reparam = jit(partial(
    reparam, 
    precision_mask=precision_mask, 
    diag_indices=diag_indices,
    p=p[0]
))

d = jnp.zeros(shape=(p[1],))
D = jnp.zeros(shape=(p[1], p[0]))
f = jnp.zeros(shape=(p[1],))
F = jnp.zeros(shape=(p[1], p[0]))

precision_mask2 = jnp.diag_indices(p[1])
diag_indices2 = jnp.diag_indices(p[1])

# partially evaluating the function 
cond_reparam = jit(partial(
    cond_reparam, 
    precision_mask=precision_mask2, 
    diag_indices=diag_indices2,
    p=p[1]
))

v_params = {
    "mu": mu, 
    "T": T,
    "d": d, 
    "D": D,  
    "f": f, 
    "F": F
}

key = random.PRNGKey(2022)
eps = random.normal(key, (p.sum(),))

eps_g = eps[:p[0]]
eps_l = eps[p[0]:]

theta = cs_gva(
    v_params, 
    eps_g, 
    eps_l
)
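The packed vector T above has length p[0]*(p[0]+1)/2 because `jnp.mask_indices(p[0], jnp.triu)` returns exactly that many (row, col) pairs, in row-major order, and `.at[precision_mask].set(...)` scatters the packed values into the upper triangle. A small illustration with made-up packed values:

```python
import jax.numpy as jnp

p_g = 5  # number of global parameters, matching p[0] in the example
precision_mask = jnp.mask_indices(p_g, jnp.triu)  # (rows, cols) of the upper triangle
diag_indices = jnp.diag_indices(p_g)

n_free = precision_mask[0].shape[0]  # number of free entries in the triangular factor
packed = jnp.arange(1.0, n_free + 1.0)  # made-up values 1, 2, ..., n_free

# Scatter the packed vector into a p_g x p_g upper-triangular matrix.
T = jnp.zeros((p_g, p_g)).at[precision_mask].set(packed)
print(n_free, T[diag_indices])
```

The diagonal entries land at packed positions 1, 6, 10, 13 and 15, which are the values `reparam` later exponentiates.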

Calling this function works and does what is expected. The next challenge is to use automatic differentiation to evaluate a vector-Jacobian product, where the cotangent passed to vjp_function is some function of theta (mathematically, the vector of derivatives of the log probability with respect to theta). For simplicity, I do not apply any function to theta here before using it as the cotangent:

def calc_vjp(
    generative: Callable,
    dlogp: Callable,
    v_params: dict,
    eps_g: jnp.ndarray,
    eps_l: jnp.ndarray
):
    """
    calculate the vector-Jacobian product of the `generative` function with respect 
    to the parameters `v_params`, using a cotangent vector that is meant to be the 
    output of `dlogp` at the current model parameters (theta itself is used for now)

    :param generative: the function that takes as input the variational parameters 
        and noise and outputs the model parameters 
    :param dlogp: evaluates the derivative of the log probability at the current 
        model parameters 
    :param v_params: value of the variational parameters, stored as a dictionary
    :param eps_g: noise generated for the global parameters
    :param eps_l: noise generated for the local parameters
    """
    theta, vjp_function = vjp(
        generative, 
        v_params, 
        eps_g, 
        eps_l, 
        has_aux=False
    )
    return vjp_function(theta)[0], theta

dlogp = lambda theta: theta  # placeholder: calc_vjp does not call dlogp yet

grads, theta = calc_vjp(
    cs_gva, 
    dlogp, 
    v_params, 
    eps_g, 
    eps_l
)
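Once dlogp(theta) is used as the cotangent, the function returned by `vjp` gives, by the chain rule, the gradient of log p(generative(v_params, eps)) with respect to `v_params`, returned as a dictionary with the same keys as `v_params` (the `[0]` selects the cotangent for the first primal argument). A toy check, where `generative`, `logp` and the keys `mu`/`log_sigma` are hypothetical stand-ins rather than CS-GVA:

```python
import jax
import jax.numpy as jnp

def generative(v_params, eps):
    # hypothetical stand-in for cs_gva: theta = mu + exp(log_sigma) * eps
    return v_params["mu"] + jnp.exp(v_params["log_sigma"]) * eps

def logp(theta):
    # hypothetical target: standard normal log density, up to a constant
    return -0.5 * jnp.sum(theta ** 2)

dlogp = jax.grad(logp)  # derivative of the log probability w.r.t. theta

v_params = {"mu": jnp.array([0.5, -1.0]), "log_sigma": jnp.array([0.1, -0.3])}
eps = jnp.array([0.3, 1.2])

# VJP route, as in calc_vjp but with dlogp(theta) as the cotangent
theta, vjp_fn = jax.vjp(generative, v_params, eps)
grads_vjp = vjp_fn(dlogp(theta))[0]  # [0]: cotangent for v_params (a dict)

# Direct route: differentiate log p(generative(v_params, eps)) w.r.t. v_params
grads_direct = jax.grad(lambda vp: logp(generative(vp, eps)))(v_params)
```

Both routes should agree key by key; the VJP route is useful when dlogp is available but log p itself is not convenient to differentiate through.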

My struggle at the moment is that I want to do something called "importance weighting": at each iteration I want to draw multiple eps vectors to generate multiple sets of model parameters, weight them, and use a weighted combination to estimate the gradient updates (hopefully this approach leads to more stable gradient estimates):

eps_v = random.normal(key, (100, p.sum()))

eps_v_g = eps_v[:, :p[0]]
eps_v_l = eps_v[:, p[0]:]

cs_gva_vmapped = jit(vmap(
    cs_gva, 
    in_axes=(None, 0, 0)
))

theta = cs_gva_vmapped(
    v_params, 
    eps_v_g, 
    eps_v_l
)

Vmapping the cs_gva function is straightforward, and it works as expected. However, every attempt to do the same with the calc_vjp function fails, and I don't understand why. For example:

vmap(
    calc_vjp, 
    in_axes=(None, None, None, 0, 0)
)(cs_gva, dlogp, v_params, eps_g, eps_l)
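Without the error message this is partly a guess, but one concrete problem in the call above is that the arguments mapped with `in_axes=0` are the single draws `eps_g` (shape (5,)) and `eps_l` (shape (537,)) rather than the batched `eps_v_g` and `eps_v_l`; vmap requires every mapped argument to have the same size along the mapped axis, so it rejects these. The sketch below uses toy stand-ins (`generative`, `logp`, `log_q` and the keys `mu`, `log_sigma`, `d` are hypothetical, not CS-GVA) to show batching the VJP over 100 rows of noise, with the callables bound via `functools.partial` so vmap only receives arrays, followed by a self-normalised importance-weighted combination of the per-draw gradients:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import random

def generative(v_params, eps_g, eps_l):
    # hypothetical stand-in for cs_gva: global and local draws, stacked
    theta_g = v_params["mu"] + jnp.exp(v_params["log_sigma"]) * eps_g
    theta_l = v_params["d"] + eps_l
    return jnp.hstack([theta_g, theta_l])

def logp(theta):
    # hypothetical target: standard normal log density, up to a constant
    return -0.5 * jnp.sum(theta ** 2)

def log_q(v_params, eps):
    # log density of the toy generative's output, up to the same constant as logp
    return -0.5 * jnp.sum(eps ** 2) - jnp.sum(v_params["log_sigma"])

dlogp = jax.grad(logp)

def calc_vjp(generative, dlogp, v_params, eps_g, eps_l):
    theta, vjp_fn = jax.vjp(generative, v_params, eps_g, eps_l)
    return vjp_fn(dlogp(theta))[0], theta

# Bind the callables first, so vmap only sees v_params and the noise arrays.
calc_vjp_batched = jax.vmap(
    partial(calc_vjp, generative, dlogp),
    in_axes=(None, 0, 0),  # v_params shared across draws; one noise row per draw
)

n_draws, p_g, p_l = 100, 2, 3  # made-up sizes
v_params = {"mu": jnp.zeros(p_g), "log_sigma": jnp.zeros(p_g), "d": jnp.zeros(p_l)}
eps_v = random.normal(random.PRNGKey(2022), (n_draws, p_g + p_l))
eps_v_g, eps_v_l = eps_v[:, :p_g], eps_v[:, p_g:]  # map over the batched noise

grads, theta = calc_vjp_batched(v_params, eps_v_g, eps_v_l)  # leading axis = draw

# Self-normalised importance weights from per-draw log weights log p - log q.
log_w = jax.vmap(logp)(theta) - jax.vmap(partial(log_q, v_params))(eps_v)
w = jax.nn.softmax(log_w)
combined = jax.tree_util.tree_map(lambda g: jnp.tensordot(w, g, axes=1), grads)
```

The weighted sum here only combines VJPs of dlogp, i.e. the log p part of the gradient; an importance-weighted objective such as the one in the paper may also involve gradients of log q, so this illustrates the batching and weighting mechanics rather than the paper's estimator.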

Any help would be greatly appreciated.

hasco641
  • I find it really hard to understand what's happening given such an abstract description of the problem, and I doubt you'll find anyone who can answer this question as it's currently written. You might consider putting together a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). – jakevdp Oct 15 '22 at 14:36
  • Hi, I agree. I have added an example; let me know if it is too long-winded. I had trouble constructing a simpler example that produced the same error. – hasco641 Oct 21 '22 at 02:57

0 Answers