I ran into the exact same problem. Unfortunately, working with Pytrees in JAX can be awkward. I was also looking for a way to construct the diagonal of the Hessian entry by entry, since that could yield a practical method.
I now have the following:
import math
from typing import Any, Optional, Sequence

import jax
import jax.numpy as jnp


def ravelled_diagonal_indices(dims: Sequence[int]) -> jnp.ndarray:
    # Get the indices for the diagonal elements of a flattened square matrix.
    return (dims[0] + 1) * jnp.arange(dims[0])


# Alias to reduce clutter.
_diag_idx = ravelled_diagonal_indices


def tree_matrix_diagonal(tree: Any, reference: Optional[Any] = None) -> Any:
    """Utility function for extracting the diagonal of a Pytree of jax.numpy.array objects.

    The Pytree is assumed to be square in its children and in its array objects.

    Parameters
    ----------
    tree : Any
        Pytree of jax.numpy.array objects for which the number of Pytree leaves and
        the size of each constituent array are both perfect squares.
    reference : Any, default = None
        The intended structure for the diagonal of `tree`. For example, this can be
        the Pytree from which `tree` was created, e.g., through an outer product
        or as the Hessian of a function.

    Returns
    -------
    diag : Any
        Pytree containing the flattened diagonals of `tree` if no reference was provided.
        Otherwise, the diagonal elements are shaped according to the structure of `reference`.
    """
    flat = jax.tree_util.tree_leaves(tree)
    # Number of leaves along one side of the square block structure.
    h = math.isqrt(len(flat))
    _idx = _diag_idx((h,))
    # Keep only the leaves on the block diagonal, i.e. blocks (i, i).
    block_diag = [flat[int(i)] for i in _idx]
    flat_diagonal = lambda w: w.ravel()[_diag_idx((math.isqrt(w.size),))]
    diag = jax.tree_util.tree_map(flat_diagonal, block_diag)
    if reference is not None:
        # Reshape the diagonal Pytree to reference Pytree structure and shape
        diag_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(reference), diag)
        # tree_map accepts several trees; jax.tree_multimap was removed from newer JAX releases.
        diag = jax.tree_util.tree_map(lambda a, b: a.reshape(jnp.shape(b)), diag_tree, reference)
    return diag
When I try this out on the Hessian of a very simple MLP:
params
>> {'dense/~/affine': {'weights': DeviceArray([[ 1. , 1. ],
[ 0.546326 , -0.77997607]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[ 1. ],
[-0.5155028],
[ 0.9487318]], dtype=float32)}}
hessian
>> {'dense/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[-0.02324889, 0.04278728],
[ 0.00814307, -0.01498652]],
[[ 0.04278728, -0.07874574],
[-0.01498652, 0.0275812 ]]],
[[[ 0.00814307, -0.01498652],
[-0.00285216, 0.00524912]],
[[-0.01498652, 0.0275812 ],
[ 0.00524912, -0.00966049]]]]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[[[[ 0.04509945],
[ 0.15897979],
[ 0.05742025]],
[[-0.08300105],
[-0.06711845],
[ 0.01683405]]],
[[[-0.01579637],
[-0.05568369],
[-0.02011181]],
[[ 0.02907166],
[ 0.02350867],
[-0.00589623]]]]], dtype=float32)}}},
'dense_1/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[ 0.04509945, -0.08300105],
[-0.01579637, 0.02907165]]],
[[[ 0.15897979, -0.06711845],
[-0.0556837 , 0.02350867]]],
[[[ 0.05742024, 0.01683406],
[-0.02011181, -0.00589624]]]]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[[[[-0.08748633],
[-0.07074545],
[-0.11138687]]],
[[[-0.07074545],
[-0.05720801],
[-0.09007253]]],
[[[-0.11138687],
[-0.09007251],
[-0.14181684]]]]], dtype=float32)}}}}
Then, the function returns:
tree_matrix_diagonal(hessian, reference=params)
>> {'dense/~/affine': {'weights': DeviceArray([[-0.02324889, -0.07874574],
[-0.00285216, -0.00966049]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[-0.08748633],
[-0.05720801],
[-0.14181684]], dtype=float32)}}
Upon visual inspection, you can see that the returned elements are indeed the diagonal elements of hessian, cast to the canonical structure of params.
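If you would rather not rely on visual inspection, one way to cross-check the result is to flatten the parameters with jax.flatten_util.ravel_pytree, take the Hessian over the flat vector, and unravel its diagonal. The sketch below is only an illustration: the loss function, the parameter values, and the helper name hessian_diagonal_via_ravel are my own assumptions and not the MLP from above. It also materialises the full Hessian, so it is only sensible for small models.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def loss(params, x):
    # Hypothetical two-layer network without biases, standing in for the MLP above.
    hidden = jnp.tanh(x @ params["dense"]["weights"])
    return jnp.sum(hidden @ params["dense_1"]["weights"])


def hessian_diagonal_via_ravel(fun, params, *args):
    # Flatten the parameters into one vector and keep the inverse mapping.
    flat_params, unravel = ravel_pytree(params)
    # Hessian of the same function over the flat vector: an ordinary (n, n) matrix.
    flat_hessian = jax.hessian(lambda p: fun(unravel(p), *args))(flat_params)
    # The diagonal has length n, so it unravels into the structure of params.
    return unravel(jnp.diag(flat_hessian))


params = {
    "dense": {"weights": jnp.array([[1.0, 1.0], [0.5, -0.8]])},
    "dense_1": {"weights": jnp.array([[1.0], [-0.5]])},
}
x = jnp.array([[0.3, -1.2]])

diag = hessian_diagonal_via_ravel(loss, params, x)
print(jax.tree_util.tree_map(jnp.shape, diag))
```

The output of this helper should agree with tree_matrix_diagonal(hessian, reference=params) for the same function and parameters, which makes it a convenient test oracle.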
Funnily enough, for the diagonal of the Gauss-Newton approximation to the Hessian the procedure is much simpler. Simply take the element-wise square of the Jacobians and sum over the output dimension :).
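A minimal sketch of that Gauss-Newton diagonal, assuming a squared-error loss so that the approximation is J^T J up to a constant factor. The predict function, the parameter values, and the helper name gauss_newton_diagonal are hypothetical and only serve as an illustration.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Hypothetical two-layer network without biases; returns a vector of outputs.
    hidden = jnp.tanh(x @ params["dense"]["weights"])
    return (hidden @ params["dense_1"]["weights"]).ravel()


def gauss_newton_diagonal(fun, params, *args):
    # Jacobian Pytree: each leaf has shape (n_outputs, *param.shape).
    jac = jax.jacobian(fun)(params, *args)
    # diag(J^T J): square every entry and sum over the output axis.
    return jax.tree_util.tree_map(lambda j: jnp.sum(j ** 2, axis=0), jac)


params = {
    "dense": {"weights": jnp.array([[1.0, 1.0], [0.5, -0.8]])},
    "dense_1": {"weights": jnp.array([[1.0], [-0.5]])},
}
x = jnp.array([[0.3, -1.2], [0.7, 0.1], [-0.4, 0.9]])

gn_diag = gauss_newton_diagonal(predict, params, x)
print(jax.tree_util.tree_map(jnp.shape, gn_diag))
```

Because the Jacobian Pytree already mirrors the structure of params, no index bookkeeping is needed here, and every entry of the result is non-negative by construction.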