
I want to maintain a list of constant values inside a tf.while_loop that supports the following operations:

  1. Read and write (multiple times) a value at a given index.
  2. Run tf.cond on it by comparing the value at an index against some constant.

A TensorArray does not work here because it does not allow writing to the same index more than once. What other options do I have?

AloneTogether
Sid Anand
Comment: I suggest you migrate to TensorFlow 2, as most people who read this are probably using TF2. – Jav, Dec 06 '21 at 08:37

1 Answer


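You can define an ordinary Tensor, carry it through the loop as a loop variable, and update it with tf.tensor_scatter_nd_update, which returns a new tensor with the given indices overwritten. Because every iteration produces a fresh tensor, the same index can be rewritten as often as needed: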

# Colab-only magic that selects TensorFlow 1.x; omit this line outside Colab.
%tensorflow_version 1.x

import tensorflow as tf

data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
data_tensor = tf.zeros_like(data)  # the rewritable "list" carried through the loop
tensor_size = data_tensor.shape[0]

init_state = (0, data_tensor)  # (loop counter, tensor being updated)
condition = lambda i, _: i < tensor_size

def custom_body(i, tensor):
  special_index = 3  # index whose value should be overwritten with new_value
  new_value = 8.0
  # With a scalar condition, tf.where selects the whole updated tensor
  # from one of the two branches (both branches are computed).
  tensor = tf.where(tf.equal(i, special_index),
                    tf.tensor_scatter_nd_update(tensor, [[special_index]], [new_value]),
                    tf.tensor_scatter_nd_update(tensor, [[i]], [data[i] * 2]))

  return i + 1, tensor


_, final_result = tf.while_loop(condition, custom_body, init_state)

with tf.Session() as sess:
  final_result_values = final_result.eval()

print(final_result_values)
# [2. 2. 2. 8. 2. 0. 2. 2. 0. 0.]
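For readers on TensorFlow 2, as the comment on the question suggests, the same idea carries over to a tf.function. The following is a minimal sketch that assumes TF 2.x; the function name `fill` and the replacement value `FILL_VALUE = -1.0` are hypothetical choices for illustration. It also covers the question's second requirement directly: after each write it reads the value back at index i, compares it against the constant 0.0 with tf.cond, and rewrites the same index when they match. Unlike the tf.where version, tf.cond only runs the branch that is selected.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

data = tf.constant([1, 1, 1, 0, 1, 0, 1, 1, 0, 0], dtype=tf.float32)
FILL_VALUE = -1.0  # hypothetical replacement for entries that end up equal to 0


@tf.function
def fill(data):
    n = tf.shape(data)[0]
    state = tf.zeros_like(data)  # the rewritable "list" carried through the loop

    def body(i, state):
        # First write at index i: store data[i] * 2.
        state = tf.tensor_scatter_nd_update(state, [[i]], [data[i] * 2.0])
        # Read the value back at index i and compare it with a constant.
        state = tf.cond(
            tf.equal(state[i], 0.0),
            # Second write to the same index; fine because state is a plain tensor.
            lambda: tf.tensor_scatter_nd_update(state, [[i]], [FILL_VALUE]),
            lambda: state,
        )
        return i + 1, state

    _, result = tf.while_loop(lambda i, _: i < n, body, (tf.constant(0), state))
    return result


print(fill(data).numpy().tolist())
# [2.0, 2.0, 2.0, -1.0, 2.0, -1.0, 2.0, 2.0, -1.0, -1.0]
```

The loop variable is an ordinary tensor rather than a TensorArray. Each tensor_scatter_nd_update call therefore produces a new tensor, and nothing prevents writing to the same index twice within one iteration or across iterations.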
AloneTogether