I am trying to use @tf.function(jit_compile=True) to create a TF graph that contains a while loop. Below is pseudocode for it. I can't provide working code because it has a lot of dependencies.
Code 1
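@tf.function(jit_compile=True)
def myfunction(inputs, model):
    tf.while():                                # outer loop (pseudocode)
        out3 = inputs
        tf.while_loop(number_samples):         # inner loop over number_samples (pseudocode)
            model = tf.keras.models.load_model()   # model is reloaded on every inner iteration
            out2 = model(out3)
            out3 = function2(out2)
        inputs = function3(out3)
    return out3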
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
    model = tf.keras.models.load_model()   # loaded once, before the loop (still inside tf.function)
    tf.while():                            # loop (pseudocode)
        out3 = inputs
        out2 = model(out3)
        out3 = function2(out2)
        inputs = function3(out3)
    return out3
Code 1 causes a memory explosion because the model is loaded inside the while loop. When I load the model outside both while loops instead, I get this error:
RuntimeError: Cannot get session inside Tensorflow graph function
What is the best way to prevent the memory explosion?
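For reference, here is a minimal sketch of the "load once, then loop" pattern. It assumes the error in Code 2 comes from calling tf.keras.models.load_model inside the tf.function (loading is a Python-side step that is not meant to be traced). The model is created eagerly before tracing, and the compiled tf.while_loop captures it, so its weights exist once and are reused on every iteration. build_model, run and num_steps are hypothetical names: build_model stands in for load_model so the snippet runs without a saved file, and the loop body stands in for the real model/function2 step.

```python
import tensorflow as tf

def build_model():
    # Hypothetical stand-in for tf.keras.models.load_model(path).
    # In real code, load the saved model here, eagerly, outside any tf.function.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(4, activation="tanh"),
    ])

model = build_model()  # created once, before run() is traced

@tf.function(jit_compile=True)
def run(inputs, num_steps):
    # `model` is captured from the enclosing scope; no loading happens here.
    def cond(i, x):
        return i < num_steps

    def body(i, x):
        # Stand-in for out3 = function2(model(out3)).
        return i + 1, model(x)

    _, out = tf.while_loop(cond, body, (tf.constant(0), inputs))
    return out

x = tf.random.normal((8, 4))
y = run(x, tf.constant(3))
print(y.shape)  # (8, 4)
```

Passing the already-loaded model as an argument, as in the original myfunction(inputs, model) signature, should also work; the key point in this sketch is only that loading happens outside the traced function.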
Edit 1: The inputs are tensors. The problem is that I need to pass a large batch at once. To do this, I wrote a while loop, thinking its iterations would run in parallel (keras model() can only process 32 samples at once). I am not sure why a Keras model does not take a batch size as an input. In the code above, is it preferable to load the weights directly and compute the outputs manually? In Code 2, will each model call get its own graph because it is called inside the while loop?
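On the 32-sample point, a small check may help. As far as I know, calling model(x) directly runs the whole tensor through the model in one pass; the number 32 is the default of the batch_size argument of Model.predict, which splits its input into batches itself. The tiny Dense model below is a hypothetical stand-in for the loaded model.

```python
import tensorflow as tf

# Hypothetical stand-in for the loaded model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])

big_batch = tf.random.normal((1000, 4))

@tf.function(jit_compile=True)
def forward(x):
    # A direct call processes all 1000 rows in a single forward pass.
    return model(x)

out = forward(big_batch)
print(out.shape)  # (1000, 2)

# predict() batches internally; batch_size defaults to 32 but can be set.
# It is called eagerly here, outside any tf.function.
out_predict = model.predict(big_batch, batch_size=256, verbose=0)
print(out_predict.shape)  # (1000, 2)
```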
Edit 2: function3 computes the gradient of out3 with respect to inputs.
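A gradient like the one described for function3 can be taken inside a jit-compiled tf.function with tf.GradientTape. The sketch below assumes out3 has one value per sample; function2 and step are hypothetical placeholders, not the real functions. Because inputs is a plain tensor rather than a Variable, it has to be watched explicitly. tape.gradient sums over out3, but since each row of out3 depends only on its own row of inputs here, the result equals the per-sample gradients.

```python
import tensorflow as tf

# Hypothetical stand-in for the loaded model, created outside tf.function.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(3, activation="tanh"),
])

def function2(out2):
    # Hypothetical placeholder for the real function2: one value per sample.
    return tf.reduce_sum(out2 ** 2, axis=-1)

@tf.function(jit_compile=True)
def step(inputs):
    with tf.GradientTape() as tape:
        tape.watch(inputs)  # needed because inputs is not a tf.Variable
        out3 = function2(model(inputs))
    # d(sum(out3)) / d(inputs); same shape as inputs.
    grad = tape.gradient(out3, inputs)
    return out3, grad

x = tf.random.normal((5, 3))
out3, grad = step(x)
print(out3.shape, grad.shape)  # (5,) (5, 3)
```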