
I am trying to update the weights once per epoch, but I am processing the data in batches. The problem is that, to normalize the loss, I need to open the TensorFlow gradient tape outside the training loop (so that all batches are tracked and the loss can be normalized). But when I do this, the training time is huge.

I think it accumulates the variables from all batches into the graph and calculates the gradients at the end.

I have tried starting the tape both outside and inside the for loop, and the latter is faster than the former. I am confused about why this happens, because whatever I do, my model's trainable variables and loss remain the same.

# Very Slow

loss_value = 0
batches = 0

with tf.GradientTape() as tape:
    for inputs, min_seq in zip(dataset, minutes_sequence):
        temp_loss_value = my_loss_function(inputs, min_seq)
        batches +=1
        loss_value = loss_value + temp_loss_value

# The following line takes huge time.
grads = tape.gradient(loss_value, model.trainable_variables)

# Very Fast

loss_value = 0
batches = 0

for inputs, min_seq in zip(dataset, minutes_sequence):
    with tf.GradientTape() as tape:
        temp_loss_value = my_loss_function(inputs, min_seq)
        batches +=1
        loss_value = loss_value + temp_loss_value

# If I add the following line, the graph breaks because it is outside the tape's scope.
    loss_value = loss_value / batches

# the following line takes huge time
grads = tape.gradient(loss_value, model.trainable_variables)

When I declare tf.GradientTape() inside the for loop it is very fast, but when I declare it outside the loop it is slow.

P.S. - This is for a custom loss and the architecture contains just one hidden layer of size 10.

I want to know what difference the position of tf.GradientTape() makes, and how it should be used to update the weights once per epoch with a batched dataset.

gauravtolani

1 Answer


The tape is used primarily to watch trainable variables and record the operations performed on them during the forward pass, so that the gradient of the loss with respect to those variables can be calculated. It is an implementation of the Python context manager construct, used here to record what happens to the variables inside the with block. An excellent resource on Python context managers is here. So if the tape is inside the loop, it records only the forward pass for that batch, so that we can calculate the gradient for all those variables in one shot (instead of stack-based gradient passing as in a naive implementation without a library like TensorFlow).

If the tape is outside the loop, it records the operations for all the batches. As per the TensorFlow source code, it also flushes if you are using TF 2.0, unlike TF 1.x, where the model developer had to take care of flushing. In your example you do not have any writer set, but if a writer is set it will do that too. So for the above code it will keep recording everything (the Graph.add_to_collection method is used internally), and as more batches are recorded you should see a slowdown. The rate of slowdown will be proportional to the size of the network (the number of trainable variables) and to the number of batches recorded.
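To illustrate what a tape actually records, here is a minimal toy sketch, assuming TensorFlow 2.x in eager mode. The single variable w, the three one-number "batches" and the batch_loss function are hypothetical stand-ins for the question's model, dataset and my_loss_function. It compares one tape wrapped around the whole loop with a fresh tape opened on every iteration, as in the question's "Very Fast" snippet.

```python
import tensorflow as tf

# Hypothetical toy setup: one trainable weight and three one-number "batches".
w = tf.Variable(2.0)
batches = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]


def batch_loss(x):
    # Hypothetical stand-in for my_loss_function: loss = (w * x)^2,
    # so d(loss)/dw = 2 * w * x^2.
    return tf.square(w * x)


# Pattern 1: one tape around the whole loop records every batch.
with tf.GradientTape() as tape:
    total = tf.constant(0.0)
    for x in batches:
        total = total + batch_loss(x)
grad_all = tape.gradient(total, w)

# Pattern 2: a fresh tape per iteration; after the loop, `tape` is the
# last one, and the running total from earlier iterations is just a
# constant to it.
total = tf.constant(0.0)
for x in batches:
    with tf.GradientTape() as tape:
        total = total + batch_loss(x)
grad_last_tape = tape.gradient(total, w)

print(float(grad_all), float(grad_last_tape))  # prints: 56.0 36.0
```

With w = 2 the per-batch gradients are 4, 16 and 36. The single tape returns their sum, 56, while the per-iteration tape returns only 36, the last batch's contribution. This shows why the second pattern is fast: it differentiates much less.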

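So placing the tape inside the loop is correct. Also, the gradients should be computed and applied inside the for loop (at the same indentation level as the with statement), not outside it; otherwise you only apply gradients once, at the end of your training loop (after the last batch). Moreover, with the current placement of the gradient retrieval (after which you apply the gradients in your code, though you omitted that from the snippet), tape refers to the tape from the last iteration, which recorded only the last batch's forward pass. The gradients you get therefore come from the last batch alone, which is why your training may not be very good.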
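For the once-per-epoch update the question asks about, one common alternative to keeping a single tape open over the whole epoch is gradient accumulation. You open a fresh tape per batch, add each batch's gradients to running sums, and apply their average once at the end of the epoch. Because differentiation is linear, the averaged gradients equal the gradient of the mean batch loss, which is the normalization the question wanted. This is a sketch under stated assumptions, not the question's actual code: it assumes TensorFlow 2.x with Keras, and the toy model (one hidden layer of size 10), the random data, the SGD optimizer and the mean-squared-error stand-in for my_loss_function are all hypothetical.

```python
import tensorflow as tf

# Hypothetical toy model: one hidden layer of size 10, as in the question.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.build(input_shape=(None, 4))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Hypothetical data: 128 samples with 4 features, split into 8 batches of 16.
features = tf.random.normal((128, 4))
targets = tf.random.normal((128, 1))
dataset = tf.data.Dataset.from_tensor_slices((features, targets)).batch(16)


def my_loss_function(x, y):
    # Stand-in for the question's custom loss (assumed: mean squared error).
    return tf.reduce_mean(tf.square(model(x, training=True) - y))


def train_one_epoch():
    # One zero-filled accumulator per trainable variable.
    accum = [tf.zeros(v.shape, dtype=v.dtype) for v in model.trainable_variables]
    total_loss = 0.0
    batches = 0
    for x, y in dataset:
        # A fresh tape per batch records only this batch's forward pass.
        with tf.GradientTape() as tape:
            loss = my_loss_function(x, y)
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
        total_loss += float(loss)
        batches += 1
    # Averaging the summed gradients gives the gradient of the mean batch loss.
    mean_grads = [a / batches for a in accum]
    # A single weight update per epoch.
    optimizer.apply_gradients(zip(mean_grads, model.trainable_variables))
    return total_loss / batches, batches


for epoch in range(3):
    mean_loss, n = train_one_epoch()
    print(f"epoch {epoch}: mean batch loss {mean_loss:.4f} over {n} batches")
```

Memory stays bounded by one batch's forward pass plus one copy of the gradients, instead of growing with the number of batches recorded. To get the per-batch updates recommended above instead, call optimizer.apply_gradients right after tape.gradient inside the loop and drop the accumulators.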

One more good resource on GradientTape that I just found.

Sunny