Can someone please explain (or point me to the relevant place in the documentation that I've missed) how to properly update a tf.Variable()
in a tf.while_loop
? I am trying to update variables in the loop that will store some information until the next iteration of the loop using the assign()
method. However, this isn't doing anything.
As the values of mu_tf
and sigma_tf
are being updated by the minimizer, while step_mu
isn't, I am obviously doing something wrong, but I don't understand what it is. Specifically, I guess I should say that I know assign()
does not do anything until it is executed when the graph is run, so I know that I can do
sess.run(step_mu.assign(mu_tf))
and that will update step_mu
, but I want to do this in the loop correctly. I don't understand how to add an assign
operation to the body of the loop.
A simplified working example of what I'm doing follows here:
import numpy as np
import tensorflow as tf
mu_true = 0.5
sigma_true = 1.5
n_events = 100000
# Placeholders
X = tf.placeholder(dtype=tf.float32)
# Variables
mu_tf = tf.Variable(initial_value=tf.random_normal(shape=[], mean=0., stddev=0.1,
dtype=tf.float32),
dtype=tf.float32)
sigma_tf = tf.Variable(initial_value=tf.abs(tf.random_normal(shape=[], mean=1., stddev=0.1,
dtype=tf.float32)),
dtype=tf.float32,
constraint=lambda x: tf.abs(x))
step_mu = tf.Variable(initial_value=-99999., dtype=tf.float32)
step_loss = tf.Variable(initial_value=-99999., dtype=tf.float32)
# loss function
gaussian_dist = tf.distributions.Normal(loc=mu_tf, scale=sigma_tf)
log_prob = gaussian_dist.log_prob(value=X)
negative_log_likelihood = -1.0 * tf.reduce_sum(log_prob)
# optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
# sample data
x_sample = np.random.normal(loc=mu_true, scale=sigma_true, size=n_events)
# Construct the while loop.
def cond(step):
return tf.less(step, 10)
def body(step):
# gradient step
train_op = optimizer.minimize(loss=negative_log_likelihood)
# update step parameters
with tf.control_dependencies([train_op]):
step_mu.assign(mu_tf)
return tf.add(step,1)
loop = tf.while_loop(cond, body, [tf.constant(0)])
# Execute the graph
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
step_loss = sess.run(fetches=negative_log_likelihood, feed_dict={X: x_sample})
print('Before loop:\n')
print('mu_tf: {}'.format(sess.run(mu_tf)))
print('sigma_tf: {}'.format(sess.run(sigma_tf)))
print('step_mu: {}'.format(sess.run(step_mu)))
print('step_loss: {}\n'.format(step_loss))
sess.run(fetches=loop, feed_dict={X: x_sample})
print('After loop:\n')
print('mu_tf: {}'.format(sess.run(mu_tf)))
print('sigma_tf: {}'.format(sess.run(sigma_tf)))
print('step_mu: {}'.format(sess.run(step_mu)))
print('step_loss: {}'.format(step_loss))