
I am relatively new to PyTorch and am trying to compute the Hessian of a very simple feedforward network with respect to its weights, using torch.autograd.functional.hessian. I have been digging through the forums, but since this function was added to PyTorch relatively recently, I cannot find much information on it. Here is my simple network architecture, which comes from some sample MNIST code on Kaggle:

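import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Placeholder hyperparameters for illustration
input_size = 28 * 28   # flattened MNIST image
hidden_size = 100      # placeholder
output_size = 10       # number of MNIST classes
learning_rate = 0.01   # placeholder

class Network(nn.Module):

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.l1(x)
        x = self.relu(x)
        x = self.l3(x)
        # nn.CrossEntropyLoss already applies log_softmax internally; doing it here
        # as well is redundant (log_softmax is idempotent) but does not change the loss.
        return F.log_softmax(x, dim=1)

net = Network()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
loss_func = nn.CrossEntropyLoss()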

I then train the network for several epochs like this:

# x (float tensor, N x input_size), y (long tensor of labels), epochs and
# batch_size are defined elsewhere.
loss_log = []
for e in range(epochs):
    for i in range(0, x.shape[0], batch_size):
        x_mini = x[i:i + batch_size]
        y_mini = y[i:i + batch_size]
        optimizer.zero_grad()
        net_out = net(x_mini)  # Variable wrappers are unnecessary since PyTorch 0.4
        loss = loss_func(net_out, y_mini)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            loss_log.append(loss.item())

Then I flatten all the parameters and concatenate them into a single tensor:

param_list = []
for param in net.parameters():
    param_list.append(param.view(-1))
param_list = torch.cat(param_list)
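For reference, torch.nn.utils.parameters_to_vector performs the same flattening. Either way, the resulting vector is computed *from* the parameters rather than being an input that the loss is a function of, so hessian has nothing to differentiate with respect to. A small sketch, using a hypothetical two-layer network in place of the one above:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

# Hypothetical small network standing in for the one in the question.
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 3))

# Manual flattening, as in the question.
param_list = torch.cat([p.view(-1) for p in net.parameters()])

# Built-in helper that produces the same flat vector.
flat = parameters_to_vector(net.parameters())
print(torch.equal(param_list, flat), flat.numel())

# The vector has a grad_fn: it is an output computed from the parameters,
# not a leaf input that some loss function takes as its argument.
print(flat.grad_fn is not None)
```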

Finally, I try to compute the Hessian of the trained network with:

hessian = torch.autograd.functional.hessian(loss_func, param_list, create_graph=True)

but it fails with this error: TypeError: forward() missing 1 required positional argument: 'target'

Any help would be appreciated.

Kasra

1 Answer


Computing the Hessian with respect to the parameters of a model (as opposed to the inputs to the model) isn't well supported right now. There is some work being done on this at https://github.com/pytorch/pytorch/issues/49171, but for the moment it's very inconvenient.

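Your code has a few other problems. torch.autograd.functional.hessian expects a function that takes the tensor(s) you want to differentiate with respect to, builds the computation graph, and returns a scalar. loss_func is an nn.CrossEntropyLoss module whose forward takes (input, target), so hessian calls it with param_list alone, which is why you get the missing 'target' error. You also never specify the input to the network or the target for the loss function.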
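To illustrate the calling convention: hessian(func, inputs) calls func(inputs), expects a single-element tensor back, and returns a tensor of shape inputs.shape + inputs.shape. A loss module that also needs a target can be wrapped so the target is fixed. A minimal sketch with toy values:

```python
import torch
from torch.autograd.functional import hessian

def f(x):
    # Takes the tensor we differentiate with respect to, returns a scalar.
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0, 3.0])
H = hessian(f, x)
print(H)  # expected: diagonal matrix with 6., 12., 18. on the diagonal

# Same idea for nn.CrossEntropyLoss: fix the target, differentiate w.r.t. the logits.
loss_func = torch.nn.CrossEntropyLoss()
target = torch.tensor([0])
logits = torch.zeros(1, 3)
H_logits = hessian(lambda z: loss_func(z, target), logits)
print(H_logits.shape)  # torch.Size([1, 3, 1, 3]), i.e. logits.shape + logits.shape
```

The hessian of the cross-entropy with respect to the logits is diag(p) - p pᵀ, where p is the softmax output, so for all-zero logits it has 2/9 on the diagonal and -1/9 elsewhere.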

Here's some code that cheats a little to use the existing functional interface to compute the Hessian of the loss with respect to the model weights, and concatenates everything into a single matrix of the same form as what you were trying to build:

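import torch

# Pick a random input to the network. The Hessian depends on (src, dst), so use
# a real data batch if you want the Hessian of the training loss.
src = torch.rand(1, input_size)
# Say our target for our loss is class 1
dst = torch.ones(1, dtype=torch.long)

# named_parameters() (unlike state_dict()) skips buffers, so names and tensors line up.
keys = [name for name, _ in net.named_parameters()]
parameters = [p for _, p in net.named_parameters()]
sizes = [p.numel() for p in parameters]
ndims = sum(sizes)

def hessian_hack(*params):
    # Replace each nn.Parameter of net with the corresponding input tensor so the
    # loss becomes a function of `params`. Note: this permanently swaps the
    # Parameters for plain tensors, so net.parameters() changes afterwards.
    for key, param in zip(keys, params):
        path = key.split('.')
        cur = net
        for attr in path[:-1]:
            cur = getattr(cur, attr)  # walk down nested submodules
        delattr(cur, path[-1])
        setattr(cur, path[-1], param)
    return loss_func(net(src), dst)

# sub_hessians[i][f] is the Hessian of parameter i vs parameter f,
# with shape parameters[i].shape + parameters[f].shape
sub_hessians = torch.autograd.functional.hessian(
    hessian_hack,
    tuple(parameters),
    create_graph=True)

# We can combine them all into one big (ndims x ndims) Hessian.
hessian = torch.cat([
    torch.cat([
        sub_hessians[i][f].reshape(sizes[i], sizes[f])
        for f in range(len(sub_hessians[i]))
    ], dim=1)
    for i in range(len(sub_hessians))
], dim=0)
print(hessian)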
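On PyTorch 2.0 or later (an assumption about your install), torch.func.functional_call can run a module with substitute parameter tensors without modifying it, which avoids the attribute swapping above. The sketch below uses a small hypothetical nn.Sequential and random data in place of the question's network and MNIST batch, builds the same block-structured Hessian, and then checks it against a Hessian-vector product computed by double backward, which is one way to validate a Hessian without relying on an external tool:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # requires PyTorch 2.0+

torch.manual_seed(0)

# Hypothetical stand-in network and data (sizes are arbitrary).
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 3))
loss_func = nn.CrossEntropyLoss()
src = torch.rand(5, 2)            # small batch of inputs
dst = torch.randint(0, 3, (5,))   # matching class targets

names = [name for name, _ in net.named_parameters()]
params = tuple(p.detach() for _, p in net.named_parameters())
sizes = [p.numel() for p in params]

def loss_of_params(*flat_params):
    # Run net with the given tensors in place of its parameters;
    # the module itself is left unchanged.
    out = functional_call(net, dict(zip(names, flat_params)), (src,))
    return loss_func(out, dst)

# sub_hessians[i][j] has shape params[i].shape + params[j].shape
sub_hessians = torch.autograd.functional.hessian(loss_of_params, params)
hessian = torch.cat([
    torch.cat([sub_hessians[i][j].reshape(sizes[i], sizes[j])
               for j in range(len(params))], dim=1)
    for i in range(len(params))
], dim=0)

# Sanity check: H @ v and a double-backward Hessian-vector product
# should agree up to floating-point error.
v = torch.randn(sum(sizes))
grads = torch.autograd.grad(loss_func(net(src), dst), list(net.parameters()),
                            create_graph=True)
flat_grad = torch.cat([g.reshape(-1) for g in grads])
hvp = torch.autograd.grad(flat_grad @ v, list(net.parameters()))
hvp = torch.cat([h.reshape(-1) for h in hvp])
print(torch.allclose(hessian @ v, hvp, atol=1e-5))
```

The Hessian-vector product never materializes the full matrix, so the same check also works for networks too large for the full Hessian to fit in memory.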
mlucy
  • Note: I'm not 100% sure this code is correct; I just wrote it and it returned something plausible. It should be a starting point at least. – mlucy Sep 04 '21 at 01:39
  • This is not right. I checked it on a small network by comparing the top eigenvalues of the Hessian from this code with those from PyHessian, and they do not match. I cannot figure out what is wrong in your code, though. – Kasra Sep 21 '21 at 20:33
  • Could you please point out where I can learn more about computing the Hessian with respect to the input? You mention that this is already supported, but I find it very difficult to find any working example online. – shuhalo Feb 25 '23 at 14:06
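Regarding the last comment: when the input is what you differentiate with respect to, torch.autograd.functional.hessian can be used directly, because the input is an ordinary argument of a function you write. A minimal sketch, assuming a hypothetical stand-in network with the question's Linear/ReLU/Linear shape:

```python
import torch
import torch.nn as nn
from torch.autograd.functional import hessian

torch.manual_seed(0)

# Hypothetical stand-in network and loss.
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 3))
loss_func = nn.CrossEntropyLoss()
src = torch.rand(1, 2)
dst = torch.ones(1, dtype=torch.long)

def loss_of_input(x):
    # The parameters stay fixed; only the input varies.
    return loss_func(net(x), dst)

H = hessian(loss_of_input, src)  # shape src.shape + src.shape == (1, 2, 1, 2)
print(H.reshape(2, 2))
```

Even though the ReLU network is piecewise linear in its input, this Hessian is generally nonzero because the softmax inside the cross-entropy is curved.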