I need to compute a Hessian-vector product of a loss with respect to the model parameters a large number of times. There doesn't seem to be an efficient way to do this other than a for loop, which results in a large number of independent autograd.grad calls. My current implementation is given below; it is representative of my use case. Note that in the real case the collection of vectors v is not known beforehand.
import torch
import time
# Create model
model = torch.nn.Sequential(torch.nn.Linear(1, 500), torch.nn.Tanh(), torch.nn.Linear(500, 1))
num_param = sum(p.numel() for p in model.parameters())
# Evaluate some loss on a random dataset
x = torch.rand((10000,1))
y = torch.rand((10000,1))
y_hat = model(x)
loss = ((y_hat - y)**2).mean()
# Calculate Jacobian of loss w.r.t. model parameters
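# Since the loss is a scalar, this "Jacobian" is just the gradient vector of length num_param;
# create_graph=True is needed so it can be differentiated again for the HVPs below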
J = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
J = torch.cat([e.flatten() for e in J]) # flatten
# Calculate Hessian vector product
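# Passing v as grad_outputs differentiates (J . v) w.r.t. the parameters (double-backward trick),
# giving v^T H, which equals H v because the Hessian is symmetric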
start_time = time.time()
for i in range(10):
    v = torch.rand(num_param)
    HVP = torch.autograd.grad(J, list(model.parameters()), v, retain_graph=True)
print('Time per HVP: ', (time.time() - start_time)/10)
This takes around 0.05 s per Hessian-vector product on my machine. Is there a way to speed this up, especially considering that the Hessian itself does not change between calls?