
How would you calculate the Hessian of a loss function that involves a neural network w.r.t. the network's parameters?

For instance, consider the loss function below

using Flux: Chain, Dense, σ, crossentropy, params
using Zygote
model = Chain(
    x -> reshape(x, :, size(x, 4)),
    Dense(2, 5),
    Dense(5, 1),
    x -> σ.(x)
)
n_data = 5
input = randn(2, 1, 1, n_data)
target = randn(1, n_data)
loss = model -> crossentropy(model(input), target)

I can get the gradient w.r.t. the parameters in two ways:

Zygote.gradient(model -> loss(model), model)

or

grad = Zygote.gradient(() -> loss(model), params(model))
grad[params(model)[1]]

However, I can't find a way to get the Hessian w.r.t. the parameters. (I want to do something like Zygote.hessian(model -> loss(model), model), but Zygote.hessian does not take ::Params as an input.)
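For comparison, Zygote.hessian works fine when the argument is a plain array, e.g. the toy example below, so the issue is specifically that the model's parameters come as a Params collection rather than a single array.

Zygote.hessian(w -> sum(w.^2), randn(3))  # 3×3 Hessian of w -> sum(w.^2), i.e. 2I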

Recently, a jacobian function that accepts ::Params as an input was added to the master branch (issue #910).

I've been trying to combine gradient and jacobian to get a Hessian (since the Hessian is the Jacobian of the gradient of a function), but to no avail. I think the problem is that model is a Chain object containing generic functions like reshape and σ, which have no parameters, but I can't get past this.

grad = model -> Zygote.gradient(model -> loss(model), model)
jacob = model -> Zygote.jacobian(grad, model)
jacob(model) ## does not work

EDIT: For reference, I've implemented this in PyTorch before.


1 Answer


Not sure if this will help in your particular use case, but you could work with an approximation of the Hessian, e.g. the empirical Fisher (EF). I've used this approach to implement a Laplace approximation for Flux models (see here), inspired by this PyTorch implementation. Below I've applied it to your example.

using Flux: Chain, Dense, σ, crossentropy, params, DataLoader
using Zygote
using Random

Random.seed!(2022)
model = Chain(
    x -> reshape(x, :, size(x, 4)),
    Dense(2, 5),
    Dense(5, 1),
    x -> σ.(x)
)
n_data = 5
input = randn(2, 1, 1, n_data)
target = randn(1, n_data)
loss(x, y) = crossentropy(model(x), y)

n_params = length(reduce(vcat, [vec(θ) for θ ∈ params(model)]))
H = zeros(n_params, n_params)
data = DataLoader((input, target))

for d in data
  x, y = d
  gs = gradient(() -> loss(x, y), params(model))           # gradients for this batch
  g = reduce(vcat, [vec(gs[θ]) for θ ∈ params(model)])     # flatten into a single vector
  H += g * g'  # accumulate the empirical Fisher
end
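For completeness, here is a minimal sketch of how the EF matrix can then be used, namely as the curvature term in the Laplace approximation mentioned above, assuming an isotropic Gaussian prior with a made-up precision λ:

using LinearAlgebra

λ = 1.0                                     # prior precision; a hyperparameter, chosen arbitrarily here
posterior_precision = H + λ * I(n_params)   # EF stands in for the Hessian of the negative log-joint
posterior_covariance = inv(Symmetric(posterior_precision))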

Should there be a way to use Zygote autodiff directly (and more efficiently), I'd also be interested to see that. Using the EF for the full Hessian still scales quadratically in the number of parameters, but as shown in this NeurIPS 2021 paper you can further approximate the Hessian using (block-)diagonal factorization. The paper also shows that, in the context of Bayesian deep learning, treating only the last layer probabilistically generally yields good results, though I'm not sure whether that's relevant in your case.
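For illustration, below is a rough sketch (my own adaptation of the loop above, not taken from the paper) of the last-layer idea: restrict the EF to the parameters of the final Dense layer, which keeps the matrix small.

last_layer = model[3]   # the Dense(5, 1) layer in the example above
n_last = length(reduce(vcat, [vec(θ) for θ ∈ params(last_layer)]))
H_last = zeros(n_last, n_last)

for d in data
  x, y = d
  gs = gradient(() -> loss(x, y), params(last_layer))        # gradients w.r.t. the last layer only
  g = reduce(vcat, [vec(gs[θ]) for θ ∈ params(last_layer)])
  H_last += g * g'  # last-layer empirical Fisher
end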
