The goal is to implement a recurrent function in TensorFlow to filter a signal over time.
The input is later presented as a 5-D tensor of the form [batch, in_depth, in_height, in_width, in_channels]. I want to use tf.while_loop to iterate over in_depth and reassign values that depend on values from previous time steps. However, I am unable to reassign variable values within the loop.
To simplify the problem, I have created a 1-D version of it:
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session, tf.get_variable)

def condition(i, signal):
    # Continue while i is smaller than the number of elements in signal
    return tf.less(i, signal.shape[0])

def operation(i, signal):
    # Fetch the existing variable (the loop is built under reuse=True)
    signal = tf.get_variable("signal")
    # signal[i] = 2 * signal[i-1]
    signal = signal[i].assign(signal[i-1]*2)
    i = tf.add(i, 1)
    return (i, signal)

with tf.variable_scope("scope"):
    i = tf.constant(1)
    init = tf.constant_initializer(0)
    signal = tf.get_variable("signal", [4], tf.float32, init, trainable = False)
    # signal[0] = 1.2
    signal = tf.assign(signal[0], 1.2)

with tf.variable_scope("scope", reuse = True):
    loop_vars = [i, signal]
    i, signal = tf.while_loop(condition, operation, loop_vars, back_prop = False)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    i, signal = session.run([i, signal])
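Semantically, tf.while_loop(cond, body, loop_vars) repeats body while cond is true, threading loop_vars from one iteration into the next, roughly like `while cond(*loop_vars): loop_vars = body(*loop_vars)`. The sketch below is a toy pure-Python stand-in (the helper `py_while_loop` is hypothetical, not a TensorFlow function) applied to the same 1-D recurrence, with a body that returns a new list instead of mutating a variable. It shows the values one would expect signal to hold at the end.

```python
def py_while_loop(cond, body, loop_vars):
    # Hypothetical pure-Python stand-in for tf.while_loop's semantics:
    # repeat `body` while `cond` is true, passing the loop variables along.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def condition(i, signal):
    return i < len(signal)


def operation(i, signal):
    # Build a new list rather than mutating shared state:
    # signal[i] = 2 * signal[i - 1]
    new_signal = signal[:i] + [signal[i - 1] * 2] + signal[i + 1:]
    return (i + 1, new_signal)


i, signal = py_while_loop(condition, operation, (1, [1.2, 0.0, 0.0, 0.0]))
print(i, signal)  # 4 [1.2, 2.4, 4.8, 9.6]
```

Because every iteration receives the previous iteration's output as an argument, the dependency between steps is explicit in the data flow rather than hidden in side effects on a variable.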
tf.assign returns an operation, which has to be run in a session in order to be evaluated (see here for further details).
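The deferred-execution idea can be illustrated with a toy model in plain Python. The classes `ToyVariable` and `ToySession` are hypothetical and only mimic the concept, not TensorFlow's implementation: creating an assignment op merely records what to do, and the stored value changes only when that op is actually run.

```python
class ToyVariable:
    """Hypothetical stand-in for a graph-mode variable."""

    def __init__(self, value):
        self.value = list(value)

    def assign_at(self, index, value_fn):
        # Returns an "op": a closure describing the update. Nothing changes yet.
        def op():
            self.value[index] = value_fn()
            return list(self.value)
        return op


class ToySession:
    """Hypothetical stand-in for a session: executes only the ops it is given."""

    def run(self, op):
        return op()


v = ToyVariable([1.2, 0.0, 0.0, 0.0])
op = v.assign_at(1, lambda: v.value[0] * 2)  # op is built, not executed
before = list(v.value)                       # still [1.2, 0.0, 0.0, 0.0]
after = ToySession().run(op)                 # [1.2, 2.4, 0.0, 0.0]
```

In this toy model, an op that is never passed to `run` has no effect at all, which mirrors why building an assignment is not the same as executing it.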
I expected that TensorFlow would chain the operations within the loop and hence execute the assignments once I run a session and request signal. However, when I execute the given code and print the result, signal contains [1.2, 0, 0, 0] and i contains 4 (as expected).
What is my misconception here, and how can I change the code so that the values of signal are reassigned?