It seems that the standard way to compute the gradient of the output of a Keras model with respect to the input variables (for example, see How to compute gradient of output wrt input in Tensorflow 2.0) is something like the following:

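import tensorflow as tf

with tf.GradientTape() as tape:
    tape.watch(input)  # needed if input is a plain tensor rather than a tf.Variable
    preds = model(input)
grads = tape.gradient(preds, input)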

However, this is extremely slow when the input tensor is large (e.g. ten million observations of 500 input variables). The above code also does not seem to use the GPU at all.

When training the model using model.fit(input), it runs on the GPU and is super fast, despite the large input tensor.

Is there any way to speed up the gradient calculation?

About version

I am running Python 3.8 and TensorFlow 2.9.1. For various reasons I can only run in graph mode, i.e. with tf.compat.v1.disable_eager_execution().

Thanks in advance!

42bsk
  • model.fit is fast due to batching; you will clog any device with such a large tensor. – Dr. Snoopy Aug 22 '22 at 17:20
  • Thanks. How should I implement efficient batching for the gradient calculation? Using loops in Python is also extremely slow. – 42bsk Aug 22 '22 at 17:31
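The point about clogging the device is easy to quantify. A back-of-the-envelope sketch, assuming float32 inputs (4 bytes per value) and counting only the input tensor; the gradient with respect to the input has the same shape as the input, and the model's intermediate activations come on top of that.

```python
n_obs = 10_000_000    # observations, from the question
n_features = 500      # input variables, from the question
bytes_per_value = 4   # assumption: float32 inputs

# Memory needed just to hold the whole input tensor at once.
full_input_bytes = n_obs * n_features * bytes_per_value
# Memory for one batch of 32 rows, fit()'s default batch_size.
batch_input_bytes = 32 * n_features * bytes_per_value

print(f"whole input tensor: {full_input_bytes / 1e9:.0f} GB")   # 20 GB
print(f"one batch of 32:    {batch_input_bytes / 1e3:.0f} kB")  # 64 kB
```

Feeding the model one batch at a time keeps the per-step footprint in the kilobyte-to-megabyte range instead of tens of gigabytes.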

1 Answer

The problem is that you are not handling batches, or at least this is what I understand from the info you have given.

According to the fit() documentation, the function takes an argument batch_size, which defaults to 32:

batch_size: Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32

However, with GradientTape you have to handle batches manually: the input in your code must be a single batch, not the whole dataset.

This means you should have something like the following code:

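import tensorflow as tf

@tf.function
def step(model, x):
    # Gradient of the model output w.r.t. one batch of inputs.
    with tf.GradientTape() as tape:
        tape.watch(x)  # x is a plain tensor, not a tf.Variable, so watch it explicitly
        preds = model(x)
    return tape.gradient(preds, x)

# Ceiling division so that the last, possibly smaller, batch is included.
num_batch = (len(x_train) + batch_size - 1) // batch_size
all_grads = []
# Iterate over the batches of the dataset (one pass is enough: no weights are updated).
for batch in range(num_batch):
    images = x_train[batch * batch_size: (batch + 1) * batch_size]
    # Call the tape on a single batch.
    all_grads.append(step(model, tf.convert_to_tensor(images)))
grads = tf.concat(all_grads, axis=0)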

Also, to improve performance, I've wrapped the gradient tape inside a tf.function. On the first call, this decorator traces the Python function and builds a static graph of the operations inside it, so subsequent calls with the same input shapes and dtypes reuse that graph and can be a lot faster. See the TensorFlow guide "Better performance with tf.function" to learn more.
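Batching does not change the result: in inference mode each output row depends only on its own input row, so concatenating per-batch input gradients gives the same array as one full-tensor gradient, while only one batch at a time goes through the model. A minimal self-contained sketch, assuming eager execution (the TF 2 default, unlike the questioner's graph-mode setup) and a model with no cross-sample layers; `input_gradients`, the toy model and the batch sizes are hypothetical, and it swaps the manual slicing for `tf.data.Dataset.batch`.

```python
import numpy as np
import tensorflow as tf


def input_gradients(model, x, batch_size=1024):
    """Gradient of the model's (summed) output w.r.t. x, computed batch by batch."""
    # tf.data yields fixed-size batches (the last one may be smaller),
    # so only one batch at a time is pushed through the model.
    dataset = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)

    @tf.function
    def batch_grad(xb):
        with tf.GradientTape() as tape:
            tape.watch(xb)  # xb is a plain tensor, so watch it explicitly
            preds = model(xb, training=False)
        return tape.gradient(preds, xb)

    # Bring each batch's gradient back to host memory and stitch them together.
    return np.concatenate([batch_grad(xb).numpy() for xb in dataset], axis=0)


# Usage on a toy model (hypothetical sizes, far smaller than the question's data).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(1),
])
x = np.random.rand(2500, 5).astype("float32")
grads = input_gradients(model, x, batch_size=1000)
print(grads.shape)  # (2500, 5)
```

Here the batches have 1000, 1000 and 500 rows, so `batch_grad` is traced twice (once per shape); for the question's data a batch size of a few thousand rows would be a reasonable starting point to tune.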

ClaudiaR
  • Thanks for the reply. I am indeed not handling batches in the gradient computation. But wouldn't the difference be that the batches in model.fit run in parallel while a Python loop would not? Or does tf.function take care of that? – 42bsk Aug 22 '22 at 16:44
  • As far as I know, not even the batches in `model.fit` are handled in parallel; only the data inside a batch is. If you think about it, processing all batches in parallel would mean that all your training data fits into memory, which is almost never possible and would defeat the purpose of batches. Your tape is slow because your CPU is overloaded with data. Also, during training the errors between labels and predictions are computed and stored; storing the errors for your whole dataset at each epoch is quite inefficient. Hope this clarifies. I've also added some info about `tf.function` to the answer. – ClaudiaR Aug 22 '22 at 17:51
  • Just a note: whether something runs on CPU or GPU is decided by TF's device placement policy, not by whether you are in graph mode or not. Your loop is slow because, as ClaudiaR explained, you are not taking advantage of a TF graph, which is built the first time and then reused to optimize the process. Just pay attention that a graph is built for every input shape/type, so be careful, otherwise the tracing will just slow down your training loop (in other words, graphs are used to "remember" how to calculate the gradient instead of discovering it again every time). – Alberto Sinigaglia Aug 22 '22 at 20:31
  • Thanks ClaudiaR, I will give your suggestion with tf.function a try. – 42bsk Aug 22 '22 at 22:14
  • Alright, let me know how it goes. If this solves your issue, please consider accepting and/or upvoting the answer. – ClaudiaR Aug 23 '22 at 06:15
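The remark that a tf.function graph is built once per input shape/type can be checked directly: a plain Python side effect inside the decorated function runs only while TensorFlow traces it, not on calls that reuse an existing graph. A minimal sketch, assuming eager execution; `square_grad` and `traces` are hypothetical names.

```python
import numpy as np
import tensorflow as tf

traces = []  # appended to only when tf.function (re)traces the Python body


@tf.function
def square_grad(x):
    traces.append(tuple(x.shape))  # Python side effect: runs at trace time only
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = tf.reduce_sum(x * x)
    return tape.gradient(y, x)  # d(sum of x^2)/dx = 2x


square_grad(tf.ones([4, 3]))   # first call: traces
square_grad(tf.zeros([4, 3]))  # same shape and dtype: reuses the graph
square_grad(tf.ones([2, 3]))   # new shape: traces again
print(len(traces))  # 2
```

In a batched loop this means the smaller final batch costs one extra trace. Passing `input_signature=[tf.TensorSpec([None, 3], tf.float32)]` to `tf.function` (with the feature dimension adjusted to the real data) is one way to get a single graph for every batch size.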