
I have written some JAX code to compute the log likelihood of a mixed logit model given some inputs: a vector of parameters parms and some data X. The log-likelihood function embeds a nested fixed-point algorithm. I am also using JAX to automatically differentiate the log likelihood (in forward mode, using jax.jacfwd). This generally works, in the sense that the JAX gradient matches a numerical approximation computed with finite differences for most parameter values. However, occasionally a set of parameter values causes the JAX-computed gradient to explode, which makes the optimization fail. Gradients computed using finite differences at those same parameter values show no issues.

I believe this is caused by numerical instability due to the fixed-point algorithm nested inside the log-likelihood function. I found some tips here for improving differentiation for this type of problem; the linked page gives an example of how to compute the gradient of only the fixed-point function using jax.custom_jvp and jax.custom_vjp. However, I don't understand how to implement this in my code, given that I need the derivative of the log-likelihood function, which embeds the fixed-point function.
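
For concreteness, here is a rough sketch of how I think the pattern from that page would look for a generic fixed-point solver, using jax.custom_jvp since I differentiate in forward mode. The solver name (fixed_point_custom), the use of jax.lax.while_loop, and the gmres-based implicit-function-theorem rule are my own adaptation and may well be wrong:

from functools import partial
import jax
import jax.numpy as jnp

# Sketch of a generic solver for z* = f(a, z*) with a custom forward-mode rule,
# so that JAX does not differentiate through the iteration loop itself.
@partial(jax.custom_jvp, nondiff_argnums=(0,))
def fixed_point_custom(f, a, z_init):
    def cond(carry):
        z_prev, z = carry
        return jnp.linalg.norm(z - z_prev) > 1e-5
    def body(carry):
        _, z = carry
        return z, f(a, z)
    _, z_star = jax.lax.while_loop(cond, body, (z_init, f(a, z_init)))
    return z_star

@fixed_point_custom.defjvp
def fixed_point_custom_jvp(f, primals, tangents):
    a, z_init = primals
    a_dot, _ = tangents
    z_star = fixed_point_custom(f, a, z_init)
    # Implicit function theorem at the fixed point: (I - df/dz) z_dot = (df/da) a_dot
    _, rhs = jax.jvp(lambda a_: f(a_, z_star), (a,), (a_dot,))
    def A(v):
        # matrix-free application of (I - df/dz) to a vector v
        _, jv = jax.jvp(lambda z_: f(a, z_), (z_star,), (v,))
        return v - jv
    z_dot, _ = jax.scipy.sparse.linalg.gmres(A, rhs)
    return z_star, z_dot

My understanding is that, once the JVP rule is attached, jax.jacfwd of any function that calls fixed_point_custom should use it automatically, but I am not sure how to connect this to the log-likelihood function below.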

Here is a minimal working example that illustrates how the fixed-point algorithm is currently being used.

import jax.numpy as jnp 
import numpy as np
import jax
import scipy

def calc_log_likelihood(parms,X,y):

    # number of alternatives (second dimension of X)
    J = X.shape[1]

    # calculate market shares as the fraction of times each alternative was chosen
    shares = np.array([np.sum(y==i)/y.shape[0] for i in range(J)])
    
    # function that calculates predicted probabilities for each (i,j) cell
    def calc_predicted_prob(parms,X,deltas):
        utility = jnp.exp(jnp.append(jnp.zeros(1),deltas)[None,:] + jnp.sum(parms[None,None,:]*X,axis=2))
        prob = utility/jnp.sum(utility,axis=1)[:,None]
        return prob

    # function that finds vector of deltas that equates predicted market shares to actual market shares
    def fixed_point(parms,X,shares):
        J = X.shape[1]
        tol = 1e-5
        deltas_init = jnp.zeros(J-1)
        def contraction(parms,X,deltas,shares):
            pred_shares = jnp.mean(calc_predicted_prob(parms,X,deltas),axis=0)
            return deltas + jnp.log(shares[1:J]) - jnp.log(pred_shares[1:J])
        f = lambda z: contraction(parms,X,z,shares)

        deltas_prev, deltas = deltas_init, f(deltas_init)
        while jnp.linalg.norm(deltas_prev-deltas)>tol:
            deltas_prev, deltas = deltas, f(deltas)
        return deltas
    
    # get vector of deltas
    deltas = fixed_point(parms,X,shares)

    # calculate predicted probabilities conditional on deltas and parms
    prob = calc_predicted_prob(parms,X,deltas)
    
    # get the set of "chosen" probabilities - i.e. the element from each row that corresponds to the choice recorded in the y vector
    prob_chosen = prob[jnp.arange(prob.shape[0]),y]
    
    # compute the log-likelihood
    log_lik = jnp.log(prob_chosen)
    return -jnp.sum(log_lik)

I = 1000
J = 10
K = 2

np.random.seed(1)
X = np.random.randn(I,J,K)

parms = np.array([0.1,0.2])
y = np.random.randint(0,J,size=I)


calc_log_likelihood(parms,X,y)

grad = jax.jacfwd(lambda z: calc_log_likelihood(z,X,y))
print(grad(parms))

print(scipy.optimize.approx_fprime(parms,lambda z: calc_log_likelihood(z,X,y),epsilon=1e-3))
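
For reference, here is roughly how I imagine the sketched fixed_point_custom solver would slot into the likelihood. Everything in this block (the calc_log_likelihood_implicit name, the contraction wrapper, the zero initial deltas) is my untested guess, and I am not sure it is the right way to combine the custom rule with the full log likelihood:

def calc_log_likelihood_implicit(parms, X, y):
    # Same likelihood as above, but the deltas fixed point is solved by the
    # custom-differentiable fixed_point_custom sketched earlier.
    J = X.shape[1]
    shares = np.array([np.sum(y == i) / y.shape[0] for i in range(J)])

    def calc_predicted_prob(p, deltas):
        utility = jnp.exp(jnp.append(jnp.zeros(1), deltas)[None, :]
                          + jnp.sum(p[None, None, :] * X, axis=2))
        return utility / jnp.sum(utility, axis=1)[:, None]

    def contraction(p, deltas):
        # X and shares are closed over; they are plain numpy constants, not traced values
        pred_shares = jnp.mean(calc_predicted_prob(p, deltas), axis=0)
        return deltas + jnp.log(shares[1:J]) - jnp.log(pred_shares[1:J])

    deltas = fixed_point_custom(contraction, parms, jnp.zeros(J - 1))

    prob = calc_predicted_prob(parms, deltas)
    prob_chosen = prob[jnp.arange(prob.shape[0]), y]
    return -jnp.sum(jnp.log(prob_chosen))

print(jax.jacfwd(lambda z: calc_log_likelihood_implicit(z, X, y))(parms))

If this is the right idea, I would expect jax.jacfwd to pick up the custom JVP rule and avoid differentiating through the contraction iterations, but I have not been able to verify that it fixes the exploding gradients.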
  • Do you have an example of inputs for which the gradient blows up? I ran your example code and the `grad` output matches the finite difference output. – jakevdp Jul 07 '23 at 16:51
