I'm attempting to fit parameters to a data set using `optim()` in R. The objective function requires iterative root-solving of an equation `G`: for each observation, the predicted value `p` is the value that brings `G` (nested within the objective function) to 0, or as close to 0 as possible. I use 50 iterations of the bisection method for stability.
Here is the problem: I would really prefer to supply an analytical gradient to `optim()`, but I suspect that isn't possible when the objective function is computed iteratively. However, before I give up on the analytical gradient, I wanted to run this problem by everyone here and see if there might be a solution I'm overlooking. Any thoughts?
Note: before settling on the bisection method, I tried other root-solving methods, but all non-bracketing methods (Newton, etc.) seem to be unstable.
Below is a reproducible example of the problem. With the provided data set and the starting values passed to `optim()`, the algorithm converges just fine without an analytical gradient, but it doesn't perform as well for other data sets and starting values.
#the data set includes two input variables (x1 and x2), one value per observation
#the response for each observation is k successes out of n trials
x1=c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1.5, 1.5, 1.5, 1.5, 1.75, 1.75, 1.75, 1.75, 2, 2,
2, 2, 2.25, 2.25, 2.25, 2.25, 2.5, 2.5, 2.5, 2.5, 2.75, 2.75,
2.75, 2.75, 3, 3, 3, 3, 3.25, 3.25, 3.25, 3.25, 3.5, 3.5, 3.5,
3.5, 3.75, 3.75, 3.75, 3.75, 4, 4, 4, 4, 4.25, 4.25, 4.25, 4.25,
4.5, 4.5, 4.5, 4.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.75, 1.75,
1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75,
1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75, 1.75,
1.75, 1.75, 1.75, 1.75, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2.25, 2.25, 2.25,
2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25,
2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25,
2.25, 2.25, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5,
2.5, 2.5, 2.5, 2.5, 2.5, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75,
2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75,
2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75, 2.75,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25,
3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25,
3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.25, 3.5, 3.5,
3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5,
3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5,
3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.75, 3.75,
3.75, 3.75, 3.75, 3.75, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
4, 4, 4, 4, 4, 4, 4, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25,
4.25, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25, 4.25,
4.25, 4.25, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5,
4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5, 4.5)
x2=c(0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0.1, 0.1, 0.1, 0.15, 0.15, 0.15,
0.15, 0.2, 0.2, 0.2, 0.2, 0.233, 0.233, 0.233, 0.267, 0.267,
0.267, 0.267, 0.3, 0.3, 0.3, 0.3, 0.333, 0.333, 0.333, 0.333,
0.367, 0.367, 0.367, 0.367, 0.4, 0.4, 0.4, 0.4, 0.433, 0.433,
0.433, 0.433, 0.467, 0.467, 0.467, 0.5, 0.5, 0.5, 0.5, 0.55,
0.55, 0.55, 0.55, 0.6, 0.6, 0.6, 0.6, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.267,
0.267, 0.267, 0.267, 0.333, 0.333, 0.333, 0.333, 0.4, 0.4, 0.4,
0.4, 0.467, 0.467, 0.467, 0.467, 0.55, 0.55, 0.55, 0.55, 0.15,
0.15, 0.15, 0.15, 0.233, 0.233, 0.233, 0.233, 0.3, 0.3, 0.3,
0.3, 0.367, 0.367, 0.367, 0.367, 0.433, 0.433, 0.433, 0.433,
0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.6, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2,
0.2, 0.2, 0.267, 0.267, 0.267, 0.267, 0.333, 0.333, 0.333, 0.333,
0.4, 0.4, 0.4, 0.4, 0.467, 0.467, 0.467, 0.467, 0.55, 0.55, 0.55,
0.55, 0.15, 0.15, 0.15, 0.15, 0.233, 0.233, 0.233, 0.233, 0.3,
0.3, 0.3, 0.3, 0.367, 0.367, 0.367, 0.367, 0.433, 0.433, 0.433,
0.433, 0.5, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.6, 0.1, 0.1, 0.1,
0.1, 0.2, 0.2, 0.2, 0.267, 0.267, 0.267, 0.267, 0.333, 0.333,
0.333, 0.333, 0.4, 0.4, 0.4, 0.4, 0.467, 0.467, 0.467, 0.467,
0.55, 0.55, 0.55, 0.55, 0.15, 0.15, 0.15, 0.15, 0.233, 0.233,
0.233, 0.233, 0.3, 0.3, 0.3, 0.3, 0.367, 0.367, 0.367, 0.367,
0.433, 0.433, 0.433, 0.433, 0.5, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6,
0.6, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.267, 0.267, 0.267,
0.267, 0.333, 0.333, 0.333, 0.333, 0.4, 0.4, 0.4, 0.4, 0.467,
0.467, 0.467, 0.467, 0.55, 0.55, 0.55, 0.55, 0.15, 0.15, 0.15,
0.15, 0.233, 0.233, 0.233, 0.3, 0.3, 0.3, 0.3, 0.367, 0.367,
0.367, 0.367, 0.433, 0.433, 0.433, 0.433, 0.5, 0.5, 0.5, 0.6,
0.6, 0.6, 0.6, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.267,
0.267, 0.267, 0.267, 0.333, 0.333, 0.333, 0.333, 0.4, 0.4, 0.4,
0.4, 0.467, 0.467, 0.467, 0.467, 0.55, 0.55, 0.55, 0.55, 0.15,
0.15, 0.15, 0.15, 0.233, 0.233, 0.233, 0.233, 0.3, 0.3, 0.3,
0.3, 0.367, 0.367, 0.367, 0.367, 0.433, 0.433, 0.433, 0.433,
0.5, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6, 0.6, 0.1, 0.1, 0.1, 0.1, 0.2,
0.2, 0.2, 0.2, 0.267, 0.267, 0.267, 0.267, 0.333, 0.333, 0.333,
0.15, 0.15, 0.15, 0.15, 0.233, 0.233, 0.233, 0.233, 0.3, 0.3,
0.3, 0.3, 0.367, 0.367, 0.367, 0.367, 0.433, 0.433, 0.433, 0.433,
0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.267, 0.267, 0.267,
0.267, 0.333, 0.333, 0.333, 0.333, 0.4, 0.4, 0.4, 0.4, 0.15,
0.15, 0.15, 0.15, 0.233, 0.233, 0.233, 0.233, 0.3, 0.3, 0.3,
0.3, 0.367, 0.367, 0.367, 0.367, 0.433, 0.433, 0.433, 0.433)
k=c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 1, 3,
3, 3, 3, 3, 3, 4, 2, 5, 3, 4, 7, 8, 5, 4, 5, 5, 4, 5, 5, 5, 6,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 2, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 3, 2, 4, 1, 2,
3, 4, 2, 2, 4, 4, 3, 1, 2, 0, 3, 4, 5, 5, 0, 0, 0, 0, 0, 0, 1,
0, 0, 1, 1, 2, 1, 2, 2, 0, 3, 1, 0, 2, 4, 6, 5, 5, 4, 5, 5, 5,
1, 0, 0, 0, 2, 1, 0, 1, 3, 2, 1, 1, 3, 4, 3, 4, 5, 5, 5, 5, 8,
6, 7, 6, 6, 5, 7, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 2, 1, 1, 3, 3,
2, 1, 3, 6, 2, 5, 3, 3, 5, 6, 5, 5, 5, 1, 0, 1, 1, 2, 1, 1, 1,
3, 4, 2, 5, 5, 3, 4, 4, 6, 4, 6, 5, 6, 5, 5, 5, 5, 4, 5, 5, 0,
0, 0, 0, 0, 2, 0, 2, 3, 3, 3, 2, 3, 3, 1, 4, 4, 4, 4, 3, 5, 6,
5, 5, 5, 5, 5, 1, 4, 1, 2, 2, 3, 4, 2, 5, 5, 5, 5, 5, 4, 5, 7,
6, 7, 6, 5, 5, 5, 7, 5, 5, 5, 5, 5, 0, 1, 0, 0, 3, 2, 3, 3, 1,
2, 2, 2, 4, 2, 3, 2, 5, 5, 5, 5, 4, 6, 5, 6, 5, 5, 6, 5, 3, 5,
2, 4, 5, 3, 5, 5, 6, 4, 4, 5, 5, 5, 6, 6, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 0, 0, 2, 0, 3, 2, 3, 2, 3, 4, 3, 4, 5, 5, 5, 5, 6, 4,
6, 4, 5, 7, 5, 5, 5, 6, 5, 5, 2, 3, 4, 4, 4, 4, 5, 5, 5, 6, 5,
5, 5, 5, 5, 4, 6, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 1, 0, 2, 0,
3, 5, 2, 2, 4, 5, 4, 5, 6, 6, 4, 5, 4, 5, 4, 5, 5, 5, 5, 5, 5,
6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 4, 1, 4, 4, 4, 4, 4, 3, 6, 5,
4, 3, 5, 4, 5, 6, 6, 5, 6, 5, 4, 5, 5, 5, 6, 5, 5, 5, 11, 5,
12, 5, 5, 5, 5, 4, 5, 5, 5)
n=c(5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5,
6, 5, 6, 5, 5, 5, 5, 7, 5, 6, 8, 8, 6, 5, 6, 5, 5, 5, 5, 5, 6,
5, 5, 5, 5, 7, 11, 8, 7, 5, 5, 5, 5, 7, 5, 5, 5, 5, 5, 5, 5,
4, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 5, 5, 4, 5, 5, 5, 6, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 6, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 7, 6, 7, 6, 5, 5, 5, 5,
5, 5, 5, 5, 5, 6, 5, 6, 6, 5, 5, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5,
8, 6, 7, 6, 6, 5, 7, 5, 5, 5, 5, 6, 5, 5, 5, 7, 7, 6, 5, 6, 5,
5, 5, 5, 6, 6, 4, 6, 6, 5, 5, 6, 6, 5, 5, 5, 5, 5, 5, 7, 5, 5,
4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 4, 6, 5, 6, 5, 5, 5, 5, 4, 5, 5,
5, 5, 6, 6, 5, 6, 5, 4, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 6, 5, 5,
6, 5, 5, 5, 5, 5, 5, 6, 5, 6, 7, 4, 6, 5, 5, 5, 5, 5, 5, 4, 5,
7, 6, 7, 6, 5, 5, 5, 7, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 5,
5, 5, 5, 5, 5, 4, 5, 6, 5, 5, 5, 5, 5, 7, 5, 6, 5, 5, 6, 5, 5,
5, 5, 5, 5, 5, 5, 6, 6, 5, 5, 5, 5, 5, 6, 6, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6,
5, 6, 5, 6, 7, 5, 5, 5, 6, 5, 5, 4, 5, 5, 5, 5, 6, 5, 5, 5, 6,
5, 5, 5, 5, 5, 5, 6, 5, 5, 5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 4, 5, 5, 5, 5, 5, 5, 5, 7, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
5, 6, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6,
5, 5, 5, 5, 5, 5, 6, 6, 5, 6, 5, 5, 5, 5, 5, 6, 5, 5, 5, 11,
5, 12, 5, 5, 5, 5, 4, 5, 5, 5)
#low_high contains the lower and upper bounds for the bisection method,
#one entry per observation (every bracket starts as [0, 1])
#The named list is built directly: vector("list",2) followed by assignment
#by name would append "low" and "high" after two unused NULL slots.
low_high=list(low=rep(0,length(x1)),high=rep(1,length(x1)))
#ll() passes this list to Reduce(); its length (50) controls how many
#bisection steps are carried out
low_high_list=rep(list(low_high),50)
ll=function(theta)
{
  #Binomial log-likelihood of the data for theta = (b1, m1, b2, m2), given in
  #that order (any names on theta are overwritten).
  #Reads x1, x2, k, n and low_high_list from the enclosing environment.
  #For every observation the predicted probability p is the bisection estimate
  #of the root in y of
  #  G(y) = -1 + x1/((logit(y)-b1)/m1) + x2/((logit(y)-b2)/m2)
  #Returns a single number: the sum of the per-observation log-likelihoods.
  names(theta)=c("b1","m1","b2","m2")
  b1=theta[["b1"]]
  m1=theta[["m1"]]
  b2=theta[["b2"]]
  m2=theta[["m2"]]
  #one bisection step: prv holds the current brackets; nxt is the next list
  #element handed over by Reduce and is deliberately ignored
  bisection_function=function(prv,nxt)
  {
    low_high=prv
    #G and y are both vectors with one entry per observation
    y=(low_high[["low"]]+low_high[["high"]])/2
    #qlogis(y) = log(y/(1-y)); computed once and shared by both terms of G
    logit_y=qlogis(y)
    G=-1+(x1/((logit_y-b1)/m1))+(x2/((logit_y-b2)/m2))
    #G>0 moves the lower bound up, G<0 moves the upper bound down;
    #entries with G==0 keep their bracket
    low_high[["low"]][G>0]=y[G>0]
    low_high[["high"]][G<0]=y[G<0]
    return(low_high)
  }
  #Reduce without an initial value starts from the first list element, so a
  #list of length L performs L-1 bisection steps; the midpoint taken below is
  #the final estimate
  low_high=Reduce(bisection_function,low_high_list)
  p=(low_high[["low"]]+low_high[["high"]])/2
  #sum of log likelihood for binomial distribution
  #dbinom(log=TRUE) works on the log scale, so it stays finite where
  #gamma(1+n) overflows or p^k underflows
  loglik=sum(dbinom(k,n,p,log=TRUE))
  return(loglik)
}
#starting values for the four parameters, in the order ll() expects them
theta.start <- c(b1 = -10, m1 = 10, b2 = -10, m2 = 10)
#fnscale = -1 makes optim() maximise ll; the Hessian at the optimum is returned as well
mle <- optim(par = theta.start,
             fn = ll,
             hessian = TRUE,
             control = list(fnscale = -1))
Thanks!!