I am trying to use tf.test.compute_gradient to check whether tf can properly differentiate my (rather complicated) custom loss function. However, tf.test.compute_gradient shows unexpected behavior with respect to the size of the perturbation delta (dx hereafter).
Theory
For a 1D function f(x), tf.test.compute_gradient computes the numerical Jacobian as the central difference (f(x + dx) - f(x - dx)) / (2 dx). We therefore expect the numerical Jacobian to converge to the true value, with an error E(dx) = f'''(x) dx^2 / 6 (to leading order) that goes to 0 as dx -> 0. I haven't done the calculation for higher dimensions, but I expect the behavior to be roughly the same.
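As a sanity check of the truncation-error formula, here is a minimal sketch in plain Python (float64, no TensorFlow; the helper names `central_difference` and `cube` are illustrative, not part of any library). It applies the same central difference to f(x) = x**3, for which f''' = 6 everywhere. Expanding gives ((x + dx)**3 - (x - dx)**3) / (2 dx) = 3 x**2 + dx**2, so the error should match f'''(x) dx**2 / 6 = dx**2 up to rounding.

```python
def central_difference(f, x, dx):
    """Central finite difference (f(x + dx) - f(x - dx)) / (2 dx)."""
    return (f(x + dx) - f(x - dx)) / (2 * dx)


def cube(x):
    # f(x) = x**3, so f'(x) = 3 x**2 and f'''(x) = 6
    return x ** 3


x = 1.0
for dx in [1e-1, 1e-2, 1e-3]:
    err = central_difference(cube, x, dx) - 3 * x ** 2
    predicted = 6 * dx ** 2 / 6  # f'''(x) dx**2 / 6
    print(f"dx={dx:g}  error={err:.3e}  predicted={predicted:.3e}")
```

For these step sizes the float64 rounding contribution is many orders of magnitude smaller than dx**2, so the measured and predicted columns should agree closely.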
Problem
However, it seems that E(dx) -> +inf as dx -> 0, even for a very simple function.
If the numerical Jacobian does not converge when it should, how can I test the gradients of custom functions for which tf might compute the gradient incorrectly? How can I tell rounding errors apart from errors in the gradient that tf computes automatically? Which value of delta should I choose? Does it depend on the number of dimensions of the problem (e.g. if there is a tf.reduce_sum and errors accumulate)?
Test
In the code example below, I ran the test on the log-density of a normal distribution. Since this function is quadratic in x, f'''(x) is identically zero, so the numerical Jacobian should be exact; yet max(|Jth - Jnu|) increases as delta decreases.
The problem occurs with dtype=float32 and dtype=float64, for any number of dimensions n.
Although the error is minimal for n = 1 and eval_jac_at = 1, it reaches non-negligible values for stiffer problems (eval_jac_at >> 1) or high-dimensional problems (n >> 1).
import tensorflow as tf
import tensorflow.math as tfm
from tensorflow_probability import distributions as tfd
import numpy as np
import matplotlib.pyplot as plt
"""
Comparing theoretical and numerical gradients for a simple function
using tf.test.compute_gradient as a function of the perturbation delta
The studied function is the log of multivariate normal law:
log(N(loc, scale)) = \\sum_i (x_i - loc_i)**2 / (2 * scale_i**2) + cst
"""
# number of dimensions
n = 1
# data type
dtype = "float32"
# point at which the Jacobian is evaluated (same value in every dimension)
eval_jac_at = 1
loc = tf.cast(0.0, dtype)
# .5 < scale < 1.5
scale = tf.cast(0.5 + np.random.rand(n), dtype)
norm_dist = tfd.Normal(loc=loc, scale=scale)
def f(x):
    """
    Explicit definition (without the normalization constant)
    """
    return tfm.reduce_sum(-tfm.pow(x - loc, 2) / (2 * scale ** 2))
def g(x):
    """
    Using tfp.distributions
    """
    return tfm.reduce_sum(norm_dist.log_prob(x))
def compute_err_gradient(f, delta):
    """
    Returns max(|Jnu - Jth|) for a function f and perturbation delta.
    Jacobians are evaluated at eval_jac_at * [1, ..., 1].
    """
    Jth, Jnu = tf.test.compute_gradient(f, [eval_jac_at * tf.ones(n, dtype)], delta)
    return np.max(np.abs(Jnu[0] - Jth[0]))
# Compute a series of delta = 1 / 2**i
deltas = 1 / np.power(2, np.arange(5, 14))
# Compute error on Jacobian for each value of delta
err_f = np.array([compute_err_gradient(f, delta) for delta in deltas])
err_g = np.array([compute_err_gradient(g, delta) for delta in deltas])
# Print and plot results
print()
print(f"{deltas=}")
print(f"{err_f=}")
print(f"{err_g=}")
print()
fig, ax = plt.subplots()
ax.loglog(deltas, err_f, label="f")
ax.loglog(deltas, err_g, label="g")
ax.set_xlabel("delta")
ax.set_ylabel("Jth - Jnu")
ax.legend()
ax.grid()
plt.show()
Output:
deltas=array([0.03125 , 0.015625 , 0.0078125 , 0.00390625, 0.00195312,
0.00097656, 0.00048828, 0.00024414, 0.00012207])
err_f=array([1.19209290e-07, 1.19209290e-07, 1.19209290e-07, 1.19209290e-07,
3.69548798e-06, 3.93390656e-06, 1.91926956e-05, 1.13248825e-05,
7.23600388e-05], dtype=float32)
err_g=array([1.19209290e-07, 3.69548798e-06, 3.93390656e-06, 3.93390656e-06,
1.13248825e-05, 1.13248825e-05, 4.97102737e-05, 4.97102737e-05,
4.97102737e-05], dtype=float32)
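One way to check whether the growth of the error in the output above comes from floating-point rounding rather than from TensorFlow's automatic differentiation is to apply the same central-difference formula to the same quadratic in plain NumPy and compare float32 with float64. The sketch below does that; the helper names and the fixed scale=0.7 (instead of a random scale) are assumptions for illustration. Because f''' = 0 for this quadratic, the central difference is exact in exact arithmetic, so any nonzero error here can only come from rounding in the evaluations of f, which is then divided by 2 dx.

```python
import numpy as np


def log_density(x, scale):
    # Unnormalized normal log-density with loc = 0: -x**2 / (2 scale**2)
    return -(x ** 2) / (2 * scale ** 2)


def central_difference_error(dtype, delta, x0=1.0, scale=0.7):
    """|numerical - exact| derivative of log_density at x0, computed in `dtype`."""
    # 1-element arrays keep every intermediate result in `dtype`
    x = np.array([x0], dtype=dtype)
    s = np.array([scale], dtype=dtype)
    dx = np.array([delta], dtype=dtype)
    numerical = (log_density(x + dx, s) - log_density(x - dx, s)) / (2 * dx)
    # Exact derivative -x0 / s**2, using the dtype-rounded scale, evaluated in float64
    exact = -x0 / float(s[0]) ** 2
    return abs(float(numerical[0]) - exact)


# Same series of perturbations as in the question: delta = 1 / 2**i
deltas = 1 / np.power(2.0, np.arange(5, 14))
err32 = np.array([central_difference_error(np.float32, d) for d in deltas])
err64 = np.array([central_difference_error(np.float64, d) for d in deltas])
print(f"{err32=}")
print(f"{err64=}")
```

If err64 stays many orders of magnitude below err32 while err32 grows roughly in proportion to 1 / delta, that would suggest the effect comes from rounding in the function evaluations rather than from a wrong gradient.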