
I am trying to use @tf.function(jit_compile=True) to create a TF graph that contains a while loop; the pseudocode is below. I can't provide runnable code because it has a lot of dependencies.

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while_loop(...):                       # outer loop (pseudocode)
    out3 = inputs
    tf.while_loop(number_samples):          # inner loop over the samples
      model = tf.keras.models.load_model()  # model loaded on every iteration
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()      # model loaded once, outside the loop
  tf.while_loop(...):                       # single loop (pseudocode)
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  return out3
  

Code 1 above results in a memory explosion because the model is loaded inside the while loop. When I instead load the model outside the while loops (Code 2), I get the error RuntimeError: Cannot get session inside Tensorflow graph function. What is the best way to prevent the memory explosion?

Edit 1: The inputs are tensors. The problem is that I need to pass a large batch at once. To do that I wrote a while loop, thinking that its iterations would run in parallel (a Keras model() call can only process 32 samples at a time). I am not sure why the Keras model call does not take a batch size as an input. Given the code above, is it preferable to load the weights directly and do the computation manually to get the outputs? In the case of Code 2, will each model call have its own graph because it is called inside the while loop?
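
For reference, here is a minimal sketch of the chunked pattern described above, assuming the model is loaded once outside the compiled function and the batch is split into chunks of 32 (the path, shapes, and function name are assumptions, not the actual code):

import tensorflow as tf

CHUNK = 32  # assumed per-call limit, mirroring the 32-sample constraint above
model = tf.keras.models.load_model("model_path")  # hypothetical path; loaded once, outside tf.function

@tf.function  # jit_compile=True left off here: XLA wants static shapes, which this sketch doesn't guarantee
def run_in_chunks(inputs):
    n_chunks = tf.shape(inputs)[0] // CHUNK
    acc = tf.TensorArray(inputs.dtype, size=n_chunks)  # assumes the output dtype matches the input dtype

    def body(i, acc):
        chunk = inputs[i * CHUNK:(i + 1) * CHUNK]     # slice one chunk of 32 samples
        return [i + 1, acc.write(i, model(chunk))]    # run the model on this chunk only

    _, acc = tf.while_loop(lambda i, _: i < n_chunks, body, [0, acc])
    return acc.concat()                               # reassemble the full output batch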

Edit 2: function3 computes the gradient of out3 with respect to inputs.
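
Since the internals of function3 aren't shown, here is only a hedged sketch of what its gradient step might look like with tf.GradientTape (the signature and names are assumptions):

import tensorflow as tf

def function3(inputs, model, function2):
    # The tape must watch `inputs` explicitly because it is a plain tensor,
    # not a tf.Variable; the forward pass is re-run under the tape so that
    # the gradient of out3 with respect to inputs can be taken.
    with tf.GradientTape() as tape:
        tape.watch(inputs)
        out2 = model(inputs)
        out3 = function2(out2)
    return tape.gradient(out3, inputs)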

newbie
  • Do you need real-time usage of your model? Is `inputs` an array or a tensor? When you run a model in a loop, a new graph is created with each iteration, and if you're using arrays as input, it creates copies of `inputs` with a new signature each time. Tensors avoid that issue. If you don't need real-time usage, load as much as you want into an array/tensor, then pass that collection to your model. – Djinn Aug 03 '22 at 00:53
  • I don't need real-time usage. The inputs are tensors. I thought that running a while loop wouldn't make new graphs, since the iterations are independent and can run in parallel. Passing everything as a collection to the model is not ideal because the model only processes a batch of 32 at once. I can not use .predict since I need to use .gradients in function2. Is there a way to increase the batch size of the model() call internally? – newbie Aug 05 '22 at 22:29
    "Passing everything as a collection to the model is not ideal because the model only processes a batch of 32 at once." I'm not sure I understand this. `model(x)` should process all of the input unless you memory doesn't allow all of the input data. `predict()`, on the other hand, works in batches. You can always use `predict()` then convert the output into a tensor - `tensor = tf.convert_to_tensor(array)`. – Djinn Aug 05 '22 at 22:41
  • [The documentation on predict](https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict) says that directly calling `__call__` (as you're doing) works on one batch and is meant for small numbers of inputs. Since the model was defined with batch size 32, it will use 32. You need to use `predict()` and then convert the output to tensors (see the sketch after this comment thread). – Djinn Aug 05 '22 at 22:48
  • Thanks for the answer; this is useful. After converting to a tensor, I can perform the gradient calculation on it. Also, http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf says that a while loop should not create new graphs. – newbie Aug 05 '22 at 22:50
  • I can not use predict() inside tf.function; I need to use only model(), because model.predict is a high-level endpoint that has its own tf.function. – newbie Aug 05 '22 at 23:06
  • I've also added an answer about the exploding-memory issue. – Djinn Aug 05 '22 at 23:06
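
A minimal sketch of the pattern discussed in the comments above (the model path and input shapes are assumptions): model(x) runs its input as a single batch and returns a tensor, while model.predict(x) iterates internally in batches and returns a NumPy array.

import tensorflow as tf

model = tf.keras.models.load_model("model_path")    # hypothetical path
x_small = tf.random.uniform((32, 10))               # placeholder inputs; shapes are assumptions
x_large = tf.random.uniform((1024, 10))

out_small = model(x_small)                          # __call__: one batch, returns a tf.Tensor
out_large = model.predict(x_large, batch_size=32)   # iterates internally in batches, returns a NumPy array
out_large = tf.convert_to_tensor(out_large)         # back to a tensor, as suggested above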

1 Answer


Two possible solutions:

According to the TensorFlow documentation, your problem may be the tensors that are stored for use in backpropagation; `tf.while_loop` accepts a `swap_memory` flag that lets those tensors be swapped between GPU and CPU memory.

Try:

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while_loop(..., swap_memory=True):     # outer loop (pseudocode)
    out3 = inputs
    tf.while_loop(number_samples):
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while_loop(..., swap_memory=True):     # pseudocode
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  return out3
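
For reference, swap_memory is a real keyword argument of tf.while_loop; a minimal concrete form of the pseudocode above, with a trivial placeholder body and loop bound, might look like this:

import tensorflow as tf

number_samples = 100  # placeholder loop bound

result = tf.while_loop(
    cond=lambda i: i < number_samples,
    body=lambda i: [i + 1],          # the real body would run the model step here
    loop_vars=[tf.constant(0)],
    swap_memory=True,                # allow GPU<->CPU swapping of the loop's tensors
)                                    # returns the final loop_vars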

Or, at the end of your loops, try calling `K.clear_session()` to reset the state:

from tensorflow.keras import backend as K

Code 1
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  tf.while_loop(...):                       # pseudocode
    out3 = inputs
    tf.while_loop(number_samples):
      model = tf.keras.models.load_model()
      out2 = model(out3)
      out3 = function2(out2)
    inputs = function3(out3)
  K.clear_session()
  return out3
  
Code 2
@tf.function(jit_compile=True)
def myfunction(inputs, model):
  model = tf.keras.models.load_model()
  tf.while_loop(...):                       # pseudocode
    out3 = inputs
    out2 = model(out3)
    out3 = function2(out2)
    inputs = function3(out3)
  K.clear_session()
  return out3
Djinn
  • I tried both of them, but neither worked. – newbie Aug 05 '22 at 23:13
  • Somewhere in the loops, can you include this line of code: `print([x for x in tf.get_default_graph().get_operations()])` and see if the operations increase with each call? – Djinn Aug 06 '22 at 01:52
  • Does that work with tf.function? It got printed, but the docs say that it does not work with tf.function. – newbie Aug 06 '22 at 05:26
  • Not sure. When you say it got printed, does that also mean the operations increased with each call? Or did something else print? – Djinn Aug 06 '22 at 06:05