
I want to compute the Jacobian and Hessian matrices of a neural network in PyTorch.

I know you can use the `vmap` and `jacrev` functions from the `torch.func` module, but they are much slower than the analytical (oracle) derivative:

from torch.func import vmap, jacrev
import torch
import time

a = torch.rand(10000, 10000)

def f(x):
    return (x ** 2).sum(-1)

def df(x):
    return 2 * x


t0 = time.time()
b = df(a)               # analytical Jacobian
t1 = time.time()
c = vmap(jacrev(f))(a)  # autodiff Jacobian, one row per sample
t2 = time.time()

assert torch.allclose(b, c)
print(t1 - t0, t2 - t1)

Result: `0.10568618774414062 0.9206998348236084`

Given that the analytical Jacobian is readily available for functions like this, I wonder why `jacrev` is so much slower. Am I doing something wrong?
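One thing worth noting (not a full answer, just an observation): since `f` returns a scalar per sample, its per-sample Jacobian is just a gradient, so `torch.func.grad` can be used instead of `jacrev`. It still runs a full backward pass, so it won't match the hand-written derivative, but it avoids some of `jacrev`'s generic-Jacobian machinery. A minimal sketch:

```python
import torch
from torch.func import vmap, grad

def f(x):
    # scalar output per sample, so the Jacobian is a gradient vector
    return (x ** 2).sum(-1)

a = torch.rand(100, 50)

# vmap(grad(f)) computes one gradient per row of the batch
g = vmap(grad(f))(a)

# sanity check against the analytical derivative d/dx (x^2).sum() = 2x
assert torch.allclose(g, 2 * a)
```

Timing this against `vmap(jacrev(f))` on your own machine would show whether the gap is in `jacrev` specifically or in reverse-mode autodiff in general.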

Of course, I could rewrite each layer of the network to return both its value and its Jacobian at the same time, but doing the same for the Hessian is too troublesome. It would be great if `jacrev` could be faster.

Frank Tian