It seems that the standard way to compute the gradient of a Keras model's output with respect to its inputs (see, for example, How to compute gradient of output wrt input in Tensorflow 2.0) is something like the following:
import tensorflow as tf

with tf.GradientTape() as tape:
    tape.watch(input)  # needed if `input` is a plain tensor rather than a tf.Variable
    preds = model(input)
grads = tape.gradient(preds, input)
However, this is extremely slow when the input tensor is large (e.g. ten million observations of 500 input variables). The above code also does not seem to use the GPU at all.
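For concreteness, the kind of chunked variant I have in mind looks roughly like the sketch below (the batch size is arbitrary, and input_data / model stand in for my real (10M, 500) array and network). I realize this sketch assumes eager execution, which conflicts with the constraint described under "About version" below, so I'm not sure it is even the right direction:

import numpy as np
import tensorflow as tf

# Sketch only: compute gradients chunk by chunk instead of on one giant tensor.
# `batch` is arbitrary; `input_data` and `model` are placeholders for my setup.
batch = 65536
chunks = []
for start in range(0, len(input_data), batch):
    x = tf.constant(input_data[start:start + batch])
    with tf.GradientTape() as tape:
        tape.watch(x)
        preds = model(x)
    chunks.append(tape.gradient(preds, x))
grads = tf.concat(chunks, axis=0)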
When training the model using model.fit(input), it runs on the GPU and is super fast, despite the large input tensor.
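For comparison, the training path that does use the GPU looks roughly like this (the optimizer, loss, targets, and batch size here are placeholders for my actual setup):

# Placeholder compile/fit call; details differ in my real code, but this is
# the path that saturates the GPU and finishes quickly.
model.compile(optimizer="adam", loss="mse")
model.fit(input_data, targets, batch_size=65536, epochs=1)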
Is there any way to speed up the gradient calculation?
About version
I am running Python 3.8 and TensorFlow v2.9.1. For various reasons I can only run in graph mode, i.e., with tf.compat.v1.disable_eager_execution().
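Given that constraint, I assume the graph-mode route would be something symbolic along the lines of the sketch below, building the gradient op once with tf.gradients and evaluating it through a Keras backend function. The batch size is arbitrary and model / input_data are placeholders; I have not confirmed whether this is actually faster:

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

tf.compat.v1.disable_eager_execution()

# Build the symbolic gradient op once, then push the data through in batches.
# `model` and `input_data` stand in for my actual model and (10M, 500) array.
grad_op = tf.gradients(model.output, model.input)[0]
grad_fn = K.function([model.input], [grad_op])

batch = 65536  # arbitrary
grads = np.concatenate(
    [grad_fn([input_data[i:i + batch]])[0]
     for i in range(0, len(input_data), batch)],
    axis=0,
)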
Thanks in advance!