Consider implementing a custom loss function that requires instantiating a temporary variable. If we need to implement custom gradients, TF expects the gradient functional to produce an additional output, when there should only be as many components of the gradient as there are inputs to the loss function. That is, if my understanding is correct; any corrections are appreciated.
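For reference, this is my reading of the signature tf.custom_gradient wants as soon as a variable is involved (a minimal sketch with made-up names; here the variable is captured from the outer scope):

import tensorflow as tf

v = tf.Variable(1.0)  # variable captured by the forward function

@tf.custom_gradient
def f(x):
    y = x * v

    def grad(dy, variables=None):
        # TF hands the captured variables in via `variables` and wants TWO
        # outputs: gradients w.r.t. the inputs of f, plus a list with one
        # gradient per variable.
        grad_x = dy * v        # d(x*v)/dx = v
        grad_vars = [dy * x]   # d(x*v)/dv = x
        return grad_x, grad_vars

    return y, grad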
Linking related github issue, which contains a minimal working example (MWE) and additional debugging information: https://github.com/tensorflow/tensorflow/issues/31945
The MWE, copy-pasted from the GitHub post, is:
import tensorflow as tf
# from custom_gradient import custom_gradient  # my corrected version
from tensorflow import custom_gradient

def layer(t, name):
    var = tf.Variable(1.0, dtype=tf.float32, use_resource=True, name=name)
    return t * var

@custom_gradient
def custom_gradient_layer(t):
    result = layer(t, name='outside')

    def grad(*grad_ys, variables=None):
        assert variables is not None
        print(variables)
        grads = tf.gradients(
            layer(t, name='inside'),
            [t, *variables],
            grad_ys=grad_ys,
        )
        grads = (grads[:1], grads[1:])
        return grads

    return result, grad
This throws ValueError: not enough values to unpack...
If my understanding is correct, in the adjoint method (reverse-mode autodiff) the forward pass builds the expression tree, and the reverse pass evaluates the gradients; the gradient functional is the incoming value times the partial derivative of the function we are differentiating with respect to, which could be a composite function. I can post a reference if needed.
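As a toy sketch of that (my own example, nothing from the issue): with one input, the reverse pass is just the upstream gradient times the local partial derivative:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x * x  # forward pass builds the expression

# reverse pass: dy/dx = upstream gradient (1.0) * local partial (2x) = 6.0
print(tape.gradient(y, x))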
So, with one input variable, we'd have one evaluation of the gradient. Here TF expects two, even though we only have one input variable, because of the temporary variable, which is unavoidable in some cases.
My MWE pseudocode is something like this:
@tf.custom_gradient
def custom_loss(x):
    temp = tf.Variable(tf.zeros([2 * N - 1]), dtype=tf.float32)
    ## compute loss function
    ...
    def grad(df):
        grad = df * partial_derivative
        return grad
    return loss, grad
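For completeness, here is an untested sketch of the two-output return shape that seems to satisfy TF. I've hoisted temp outside the function and substituted a stand-in loss, since the real computation is elided above; N is a placeholder size of my choosing:

import tensorflow as tf

N = 3  # placeholder size for the sketch
temp = tf.Variable(tf.zeros([2 * N - 1]), dtype=tf.float32)

@tf.custom_gradient
def custom_loss(x):
    # stand-in for the real loss computation elided above
    loss = tf.reduce_sum(x) + tf.reduce_sum(temp)

    def grad(df, variables=None):
        grad_x = df * tf.ones_like(x)          # df * partial derivative w.r.t. x
        grad_vars = [df * tf.ones_like(temp)]  # one entry per captured variable
        return grad_x, grad_vars

    return loss, grad

With that (input_grads, variable_grads) pair, the unpacking inside TF's decorator should get the two values it expects, which is what the ValueError above is complaining about.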
Andre