I have written some JAX code to compute the log likelihood of a mixed logit model given some inputs: a vector of parameters parms and some data X. The log likelihood function embeds a nested fixed point algorithm. I am also using JAX to automatically differentiate the log likelihood (in forward mode, using jax.jacfwd). This generally works, in the sense that the JAX gradient matches a numerical approximation computed with finite differences for most parameter values. However, occasionally a set of parameter values causes the JAX-computed gradient to explode, which makes the optimization fail. Gradients computed with finite differences at those same parameter values show no such issues.
I believe this is caused by numerical instability from the fixed point algorithm nested inside the log-likelihood function. I found some tips for improving differentiation of these kinds of problems here, which give an example of how to compute the gradient of the fixed point function alone using jax.custom_jvp and jax.custom_vjp. However, I don't understand how to implement this in my code, given that I need the derivative of the log-likelihood function, which embeds the fixed point function.
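For reference, my (possibly wrong) reading of that example is that it differentiates the fixed point implicitly: if deltas* solves deltas* = f(parms, deltas*), then a directional derivative deltas_dot can be recovered from the converged solution by solving the linear system (I - df/ddeltas) deltas_dot = (df/dparms) parms_dot, instead of differentiating through every iteration of the loop. Since I am differentiating in forward mode, here is my attempt at that pattern with jax.custom_jvp. The names fwd_solver and fixed_point_implicit are mine, and this is only a sketch of how I understand the idea, not code from the documentation:

from functools import partial
import jax
import jax.numpy as jnp

def fwd_solver(f, z_init, tol=1e-5):
    # plain fixed-point iteration z <- f(z); only ever sees concrete values
    z_prev, z = z_init, f(z_init)
    while jnp.linalg.norm(z_prev - z) > tol:
        z_prev, z = z, f(z)
    return z

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def fixed_point_implicit(f, a, z_init):
    # solve z = f(a, z); the custom rule below keeps autodiff out of the loop
    return fwd_solver(lambda z: f(a, z), z_init)

@fixed_point_implicit.defjvp
def fixed_point_implicit_jvp(f, primals, tangents):
    a, z_init = primals
    a_dot, _ = tangents
    # solve for the fixed point at the primal inputs
    z_star = fwd_solver(lambda z: f(a, z), z_init)
    # implicit function theorem at z*: (I - df/dz) z_dot = (df/da) a_dot
    dfdz = jax.jacfwd(lambda z: f(a, z))(z_star)
    dfda = jax.jacfwd(lambda a_: f(a_, z_star))(a)
    z_dot = jnp.linalg.solve(jnp.eye(z_star.shape[0]) - dfdz, dfda @ a_dot)
    return z_star, z_dot

If this standalone version is roughly right, what I can't work out is how to use it when the fixed point sits inside the likelihood and its output feeds into further computation that also depends on parms.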
Here is a minimum working example that illustrates how the fixed point algorithm is being used.
import jax.numpy as jnp
import numpy as np
import jax
import scipy.optimize
def calc_log_likelihood(parms, X, y):
    # number of alternatives
    J = X.shape[1]
    # market shares: the fraction of observations that chose each alternative
    shares = np.array([np.sum(y == i) / y.shape[0] for i in range(J)])

    # function that calculates predicted probabilities for each (i,j) cell
    def calc_predicted_prob(parms, X, deltas):
        utility = jnp.exp(jnp.append(jnp.zeros(1), deltas)[None, :]
                          + jnp.sum(parms[None, None, :] * X, axis=2))
        prob = utility / jnp.sum(utility, axis=1)[:, None]
        return prob

    # function that finds the vector of deltas that equates predicted market
    # shares to actual market shares
    def fixed_point(parms, X, shares):
        tol = 1e-5
        deltas_init = jnp.zeros(J - 1)

        def contraction(parms, X, deltas, shares):
            pred_shares = jnp.mean(calc_predicted_prob(parms, X, deltas), axis=0)
            return deltas + jnp.log(shares[1:J]) - jnp.log(pred_shares[1:J])

        f = lambda z: contraction(parms, X, z, shares)
        deltas_prev, deltas = deltas_init, f(deltas_init)
        while jnp.linalg.norm(deltas_prev - deltas) > tol:
            deltas_prev, deltas = deltas, f(deltas)
        return deltas

    # get vector of deltas
    deltas = fixed_point(parms, X, shares)
    # calculate predicted probabilities conditional on deltas and parms
    prob = calc_predicted_prob(parms, X, deltas)
    # get the set of "chosen" probabilities - i.e. the element from each row
    # that corresponds to the choice recorded in the y vector
    prob_chosen = prob[jnp.arange(prob.shape[0]), y]
    # compute the log-likelihood
    log_lik = jnp.log(prob_chosen)
    return -jnp.sum(log_lik)
# simulate data: I individuals, J alternatives, K covariates
I = 1000
J = 10
K = 2
np.random.seed(1)
X = np.random.randn(I, J, K)
parms = np.array([0.1, 0.2])
y = np.random.randint(0, J, size=I)

calc_log_likelihood(parms, X, y)

# compare the JAX forward-mode gradient with a finite-difference approximation
grad = jax.jacfwd(lambda z: calc_log_likelihood(z, X, y))
print(grad(parms))
print(scipy.optimize.approx_fprime(parms, lambda z: calc_log_likelihood(z, X, y), epsilon=1e-3))
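To make the question concrete, here is roughly how I imagine the example would have to be rewired so that the custom rule is actually used: the contraction closes over X and shares only, parms stays an explicit differentiable argument, and the inner fixed_point call is replaced by the fixed_point_implicit wrapper sketched above. I have not verified that this is the right way to do it, which is exactly what I am asking:

def calc_log_likelihood_implicit(parms, X, y):
    J = X.shape[1]
    shares = np.array([np.sum(y == i) / y.shape[0] for i in range(J)])

    def calc_predicted_prob(parms, X, deltas):
        utility = jnp.exp(jnp.append(jnp.zeros(1), deltas)[None, :]
                          + jnp.sum(parms[None, None, :] * X, axis=2))
        return utility / jnp.sum(utility, axis=1)[:, None]

    # the contraction closes over X and shares only, so parms remains the
    # differentiable argument seen by the custom rule
    def contraction(parms, deltas):
        pred_shares = jnp.mean(calc_predicted_prob(parms, X, deltas), axis=0)
        return deltas + jnp.log(shares[1:J]) - jnp.log(pred_shares[1:J])

    # fixed_point_implicit is the custom_jvp wrapper sketched before the example
    deltas = fixed_point_implicit(contraction, parms, jnp.zeros(J - 1))

    prob = calc_predicted_prob(parms, X, deltas)
    prob_chosen = prob[jnp.arange(prob.shape[0]), y]
    return -jnp.sum(jnp.log(prob_chosen))

print(jax.jacfwd(lambda z: calc_log_likelihood_implicit(z, X, y))(parms))

Is this the right way to combine jax.custom_jvp with a fixed point that is embedded in a larger function, or is there a cleaner way to structure it?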