Given a TensorFlow `tf.while_loop`, how can I calculate the gradient of `x_out` with respect to all weights of the network for each time step?
```python
import tensorflow as tf

network_input = tf.placeholder(tf.float32, [None])
steps = tf.constant(0)

# Weights are created once, outside the loop body: creating a
# tf.Variable inside the body fails, because its initializer
# would then live inside the while context.
weight_0 = tf.Variable(1.0)
weight_1 = tf.Variable(1.0)

layer_1 = network_input * weight_0

def condition(steps, x):
    return steps <= 5

def loop(steps, x_in):
    # weight_1 is shared across all time steps
    x_out = x_in * weight_1
    return [steps + 1, x_out]

_, x_final = tf.while_loop(
    condition,
    loop,
    [steps, layer_1]
)
```
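For reference, the failing call from the notes below looks roughly like this in context (a minimal sketch, using `x_final` as the target):

```python
# Attempt to differentiate the loop output w.r.t. all weights.
# As described in the notes below, this crashes with:
#   AttributeError: 'WhileContext' object has no attribute 'pred'
grads = tf.gradients(x_final, tf.trainable_variables())
```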
Some notes:

- In my network the condition is dynamic: different runs execute the while loop a different number of times.
- Calling `tf.gradients(x, tf.trainable_variables())` crashes with `AttributeError: 'WhileContext' object has no attribute 'pred'`. It seems the only way to use `tf.gradients` within the loop is to calculate the gradient with respect to `weight_1` and the current value of `x_in` / time step only, without backpropagating through time.
- In each time step, the network outputs a probability distribution over actions. The gradients are then needed for a policy gradient implementation; see the sketch after this list.
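To illustrate the policy-gradient use case, the only workaround I can think of is to accumulate every per-step output in a `tf.TensorArray` and differentiate a weighted sum of the stacked result after the loop. This is a sketch only: `step_weights` is a hypothetical placeholder standing in for the per-step returns, and I am not sure this is the intended approach:

```python
import tensorflow as tf

network_input = tf.placeholder(tf.float32, [None])
weight_0 = tf.Variable(1.0)
weight_1 = tf.Variable(1.0)
layer_1 = network_input * weight_0

steps = tf.constant(0)
# dynamic_size=True because the number of iterations varies per run
outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

def condition(steps, x, outputs):
    return steps <= 5

def loop(steps, x_in, outputs):
    x_out = x_in * weight_1
    outputs = outputs.write(steps, x_out)  # record this time step
    return steps + 1, x_out, outputs

_, x_final, outputs = tf.while_loop(
    condition, loop, [steps, layer_1, outputs])

per_step = outputs.stack()  # shape: [num_steps, batch]

# Hypothetical per-step weights (e.g. discounted returns). The
# gradient of this weighted sum equals the sum of the per-step
# gradients scaled by step_weights, which is what a policy
# gradient update needs.
step_weights = tf.placeholder(tf.float32, [None])
loss = tf.reduce_sum(per_step * step_weights[:, None])
grads = tf.gradients(loss, tf.trainable_variables())
```

This avoids calling `tf.gradients` inside the loop body entirely, but if there is a way to get the individual per-step gradients directly instead of through this weighted-sum trick, that would be even better.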