In JAX's Quickstart tutorial I found that the Hessian matrix can be computed efficiently for a differentiable function fun using the following lines of code:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

However, the Hessian can also be computed with any of the following compositions:

def hessian(fun):
  return jit(jacrev(jacfwd(fun)))

def hessian(fun):
  return jit(jacfwd(jacfwd(fun)))

def hessian(fun):
  return jit(jacrev(jacrev(fun)))

Here is a minimal working example:

import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev

def comp_hessian():

    x = jnp.arange(1.0, 4.0)

    def sum_logistics(x):
        return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

    def hessian_1(fun):
        return jit(jacfwd(jacrev(fun)))

    def hessian_2(fun):
        return jit(jacrev(jacfwd(fun)))

    def hessian_3(fun):
        return jit(jacrev(jacrev(fun)))

    def hessian_4(fun):
        return jit(jacfwd(jacfwd(fun)))

    hessian_fn = hessian_1(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_2(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_3(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_4(sum_logistics)
    print(hessian_fn(x))


def main():
    comp_hessian()


if __name__ == "__main__":
    main()

Which approach is best, and when? Is it possible to use grad() to compute the Hessian? And how does grad() differ from jacfwd and jacrev?

Gilfoyle

1 Answer

The answer to your question is within the JAX documentation; see for example this section: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

To quote its discussion of jacrev and jacfwd:

These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.

and further down,

To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function [with a] wide Jacobian (maybe like a loss function f:ℝⁿ→ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇f:ℝⁿ→ℝⁿ), which is where forward-mode wins out.

Since your function maps ℝⁿ→ℝ, jit(jacfwd(jacrev(fun))) is likely the most efficient approach.
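As a quick sanity check that the four compositions differ only in cost and not in value, here is a minimal sketch that compares each of them against a hand-derived Hessian for the sum_logistics function from the question. The analytic_hessian helper and the tolerance are my own additions for illustration; for this function the Hessian is diagonal with entries σ(x)(1 − σ(x))(1 − 2σ(x)).

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def analytic_hessian(x):
    # Hypothetical helper: second derivative of the logistic function,
    # placed on the diagonal because each term depends on one x[i] only.
    s = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.diag(s * (1.0 - s) * (1.0 - 2.0 * s))

x = jnp.arange(1.0, 4.0)

# The four compositions from the question.
compositions = {
    "jacfwd(jacrev)": jit(jacfwd(jacrev(sum_logistics))),
    "jacrev(jacfwd)": jit(jacrev(jacfwd(sum_logistics))),
    "jacrev(jacrev)": jit(jacrev(jacrev(sum_logistics))),
    "jacfwd(jacfwd)": jit(jacfwd(jacfwd(sum_logistics))),
}

expected = analytic_hessian(x)
results = {name: fn(x) for name, fn in compositions.items()}

# Tolerance is an assumption chosen for float32 arithmetic.
all_agree = all(
    bool(jnp.allclose(h, expected, atol=1e-5)) for h in results.values()
)
print("all compositions agree:", all_agree)
```

With only three inputs the timing difference is negligible; the forward-over-reverse advantage should only become visible as n grows.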

As for why you can't implement a Hessian by composing grad with itself: grad only differentiates functions with scalar outputs. The Hessian is the Jacobian of the gradient, and the gradient of a function f:ℝⁿ→ℝ is vector-valued, so grad(grad(fun)) does not work; the outer derivative has to be a Jacobian, computed with jacfwd or jacrev.
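To make the scalar-output restriction concrete, here is a small sketch using the sum_logistics function from the question. It shows grad working once, nested grad being rejected, and a Jacobian of the gradient producing the Hessian. Using grad as the inner step is my own illustration rather than something the cookbook passage above prescribes; it works here only because sum_logistics returns a scalar.

```python
import jax
import jax.numpy as jnp
from jax import grad, jacfwd

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 4.0)

# grad works once: sum_logistics returns a scalar, the gradient has shape (3,).
gradient = grad(sum_logistics)(x)

# Nesting grad fails: the inner grad returns a vector, and grad refuses
# to differentiate a function whose output is not a scalar.
try:
    grad(grad(sum_logistics))(x)
    nested_grad_failed = False
except TypeError:
    nested_grad_failed = True

# Taking the Jacobian of the gradient gives the Hessian instead.
hess = jacfwd(grad(sum_logistics))(x)

# jax.hessian is the built-in convenience wrapper, used here for comparison.
reference = jax.hessian(sum_logistics)(x)
matches_builtin = bool(jnp.allclose(hess, reference, atol=1e-6))

print("nested grad failed:", nested_grad_failed)
print("matches jax.hessian:", matches_builtin)
```

In practice jax.hessian(fun) is the simplest way to get the Hessian without writing the composition by hand.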

jakevdp