I want to compute the Jacobian and Hessian matrices of a neural network in PyTorch. I know I can use the vmap and jacrev functions from the torch.func module, but they are much slower than the oracle function:
from torch.func import vmap, jacrev
import torch
import time
a = torch.rand(10000, 10000)
def f(x):
    return (x ** 2).sum(-1)

def df(x):
    # hand-written (oracle) Jacobian of f
    return 2 * x
t0 = time.time()
b = df(a)
t1 = time.time()
c = vmap(jacrev(f))(a)
t2 = time.time()
assert torch.allclose(b, c)
print(t1 - t0, t2 - t1)
Result: 0.10568618774414062 0.9206998348236084
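For reference, the Hessian case I ultimately need looks roughly like the sketch below (my own minimal sketch using torch.func.hessian on a smaller input I chose for illustration, since each per-sample Hessian is d x d; for this toy f the analytical Hessian of every sample is just 2I):

from torch.func import hessian
a_small = torch.rand(100, 100)  # smaller input: each per-sample Hessian is 100 x 100
h = vmap(hessian(f))(a_small)
# analytical Hessian of f(x) = (x ** 2).sum() is 2 * I for every sample
h_oracle = 2 * torch.eye(a_small.shape[-1]).expand(a_small.shape[0], -1, -1)
assert torch.allclose(h, h_oracle)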
Given that an oracle Jacobian like this is often readily available in neural networks, I wonder why jacrev is so much slower. Am I doing something wrong?
Of course, I could rewrite each layer of the neural network to return both its value and its Jacobian at the same time (roughly the sketch below), but doing the same for the Hessian is too troublesome. It would be great if jacrev itself could be made faster.
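To make the per-layer idea concrete, here is a minimal sketch for a hypothetical two-layer tanh MLP I made up for illustration; each step propagates the batch of Jacobians via the chain rule:

import torch
from torch.func import vmap, jacrev

def mlp_value_and_jac(x, W1, b1, W2, b2):
    # linear layer: z1 = x @ W1.T + b1, so d z1 / d x = W1 for every sample
    z1 = x @ W1.T + b1
    J = W1.expand(x.shape[0], *W1.shape)
    # tanh activation: local Jacobian is diagonal, so scale the rows of J
    a1 = torch.tanh(z1)
    J = (1 - a1 ** 2).unsqueeze(-1) * J
    # second linear layer, chain rule: J <- W2 @ J (batched matmul via broadcasting)
    z2 = a1 @ W2.T + b2
    return z2, W2 @ J

x = torch.rand(8, 4)
W1, b1 = torch.rand(16, 4), torch.rand(16)
W2, b2 = torch.rand(3, 16), torch.rand(3)
y, J = mlp_value_and_jac(x, W1, b1, W2, b2)
J_ref = vmap(jacrev(lambda v: torch.tanh(v @ W1.T + b1) @ W2.T + b2))(x)
assert torch.allclose(J, J_ref, atol=1e-6)

This is manageable for the Jacobian, but propagating second derivatives through every layer the same way is exactly the part that gets tedious, which is why I would rather have a faster jacrev.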