I'm trying to implement a Theano version of conjugate gradient for Hessian-Free optimisation, as described in this paper: http://icml2010.haifa.il.ibm.com/papers/458.pdf.
The termination condition used for the conjugate gradient algorithm is that the value of the objective Phi has not decreased much relative to its value some number of iterations ago,
i.e. (Phi(t) - Phi(t-k)) / Phi(t) < k * eps, where eps is a tolerance.
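In plain Python the check I have in mind looks something like this (a minimal sketch; the function name and the `eps` default are illustrative, not from the paper's code):

```python
import numpy as np

def should_stop(phis, eps=5e-4):
    """Relative-progress test on a history of CG objective values.

    phis: list of phi values so far, phis[t] being the current one.
    phi is negative and decreasing on a well-posed problem, so the test
    only fires once enough history has accumulated.
    """
    t = len(phis) - 1
    k = max(10, int(0.1 * t))  # window length grows with iteration count
    if t < k or phis[t] >= 0:
        return False
    # stop when phi has decreased by less than k*eps (relative) over k steps
    return (phis[t] - phis[t - k]) / phis[t] < k * eps
```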
I'm using Theano scan to implement the conjugate gradient loop, so the stopping criterion requires me to feed in, as a tap, the value of Phi from K iterations ago.
For fixed K this is easy, as shown in my code below. However, K is supposed to grow with the number of iterations: K = max(10, 0.1 * iters). This would seem to require a tap whose offset is a shared variable. Is that possible? How would you do it?
I've tried using a Theano variable for K within the step function, but then scan can't use that K as input.
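One workaround I've been considering, since the tap offsets themselves must be compile-time constants: instead of a variable tap, carry a fixed-size buffer of the last K_MAX phi values as an extra recurrent output, shift it each step, and index it with the symbolic k. Sketched here in NumPy just to show the buffer logic (`K_MAX` and both function names are my own, hypothetical):

```python
import numpy as np

K_MAX = 50  # fixed upper bound on the window; buffer sizes must be known at compile time

def update_buffer(phi_buf, phi_new):
    """Shift the history buffer left by one and append the newest phi.

    After the update, phi_buf[-1] is phi(t) and phi_buf[-1 - k] is phi(t - k).
    """
    return np.concatenate([phi_buf[1:], [phi_new]])

def phi_k_ago(phi_buf, t):
    """Look up phi(t - k) with k = max(10, 0.1*t), clipped to the buffer length."""
    k = min(max(10, int(0.1 * t)), K_MAX - 1)
    return phi_buf[-1 - k]
```

In Theano terms this would be something like `T.concatenate([phi_buf[1:], phi.dimshuffle('x')])` inside the step and `phi_buf[-1 - k]` with a symbolic integer `k`, which scan should accept because only the buffer size, not the index, is fixed.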
Here's my code (imports added for completeness):
import theano
import theano.tensor as T
from theano.ifelse import ifelse

r_init = b_vec - mat_vec_func(x_init)
d_init = r_init
delta_0 = T.dot(r_init.transpose(), r_init)
# initial value of the CG objective
phi_0 = -T.dot(x_init.transpose(), r_init)
# taps=[-5, -1] need 5 initial values along axis 0, so tile the scalar
phis_0 = T.tile(phi_0.dimshuffle(['x']), 5)
def conj_step(iters, x, d, r, delta_new, phi_tm5, phi_tm1):
q = mat_vec_func(d)
alpha = delta_new/(T.dot(d.transpose(), q))
x = x + alpha*d
r = ifelse(T.eq(iters%50, 0), b_vec - mat_vec_func(x), r - alpha*q)
delta_old = delta_new
delta_new = T.dot(r.transpose(), r)
beta = delta_new/delta_old
d = r + beta*d
# calculate the CG objective to use for the stopping criterion.
phi = -T.dot(x.transpose(), r)
# this is where I'm stuck: .eval() can't be used inside the step
# function, so a variable K doesn't work this way
# k = max(10, 0.1*iters.eval())
return [x, d, r, delta_new, phi], theano.scan_module.until(T.le((phi - phi_tm5)/phi, tol))
xdrd, _ = theano.scan(conj_step,
sequences = [T.arange(max_iter)],
outputs_info=[x_init, d_init, r_init, delta_0,
dict(initial=phis_0,taps=[-5, -1])],
)
The conjugate gradient algorithm is solving the equation Ax = b, and mat_vec_func is just a function that computes Ax.
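For reference, here is the same loop in plain NumPy (my own sketch, not from the paper: it mirrors the scan version's phi = -x'r bookkeeping and the every-50-iterations residual recomputation, plus a small-residual escape so the loop can't divide by zero after convergence):

```python
import numpy as np

def cg(A, b, x0, max_iter=500, tol=5e-4):
    """Conjugate gradient for Ax = b with the relative-progress stopping rule."""
    x = x0.copy()
    r = b - A @ x
    d = r.copy()
    delta_new = r @ r
    phis = [-x @ r]                      # CG objective history
    for t in range(max_iter):
        q = A @ d
        alpha = delta_new / (d @ q)
        x = x + alpha * d
        # recompute the residual exactly every 50 iterations to limit drift
        r = b - A @ x if (t + 1) % 50 == 0 else r - alpha * q
        delta_old, delta_new = delta_new, r @ r
        d = r + (delta_new / delta_old) * d
        phis.append(-x @ r)
        if np.sqrt(delta_new) < 1e-10:   # residual already negligible
            break
        k = max(10, int(0.1 * t))
        if t >= k and phis[-1] < 0 and (phis[-1] - phis[-1 - k]) / phis[-1] < k * tol:
            break
    return x

# usage: a small symmetric positive-definite system
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x = cg(A, b, np.zeros(2))
```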
In my example K is hardcoded to 5, but I would like to make it vary.
Thanks!