TL;DR
For a function f that takes a matrix and returns a scalar:
- Find the first-order derivative; let's name it dX
- Take its trace: Tr(dX)
- To get a second-order or mixed partial derivative, differentiate that trace: d/dY[df/dX] = d/dY[Tr(df/dX)]
Intro
At the time I posted the question I did not know much about the theory of matrix derivatives, but I know much more now, all thanks to this Yandex ML book (unfortunately, I did not find an English equivalent). This is an attempt to give a full answer to my own question.
Basic Theory
Forgive me, Lord, for the ugly representation of LaTeX
Let's say you have a function which takes a matrix X and returns its squared Frobenius norm: f(X) = ||X||_F^2
It is a well-known fact that: ||X||_F^2 = Tr(X X^T)
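As a quick sanity check of this identity, here is a small JAX sketch of my own (not from the book). The matrix values are arbitrary and only serve as an illustration.

```python
import jax.numpy as jnp

# Arbitrary non-square example; the identity holds for any real matrix.
X = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])

frob_sq = jnp.linalg.norm(X, ord="fro") ** 2  # ||X||_F^2
trace_form = jnp.trace(X @ X.T)               # Tr(X X^T)
sum_sq = jnp.sum(X ** 2)                      # sum of squared entries

# All three should agree up to float rounding (91 for this X).
print(frob_sq, trace_form, sum_sq)
```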
Let's define the derivative as in the same book: the differential D[f] at X_0 is the part of f(X_0 + H) - f(X_0) that is linear in the increment H.
Now we are ready to find df(X)/dX:
df(X)/dX = dTr(X X^T)/dX =
(using the linearity of the trace)
= Tr(d/dX[X X^T]) = Tr(dX/dX X^T + X d[X^T]/dX) =
(then we use the definition of the derivative from above)
= Tr(HX^T + XH^T) = Tr(HX^T) + Tr(XH^T) =
(now the main trick is to move every H to the right side and get something like
Tr(g(X) H) or Tr(g(X) H^T), where g(X) will be the derivative we are looking for; here Tr(HX^T) = Tr(XH^T) because Tr(A) = Tr(A^T))
= Tr(XH^T) + Tr(XH^T) = Tr(2XH^T)
That means: df(X)/dX = 2X
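The result can be checked against autodiff. This is a minimal sketch of my own, assuming JAX is available; X and the direction H are arbitrary values picked for illustration. It compares jax.grad with 2X, and the differential applied to H (via jax.jvp) with Tr(2XH^T).

```python
import jax
import jax.numpy as jnp

def f(X):
    # Squared Frobenius norm written as a trace, as in the derivation.
    return jnp.trace(X @ X.T)

X = jnp.array([[1.0, -2.0, 0.5],
               [3.0, 0.0, 4.0]])
H = jnp.array([[0.1, 0.2, -0.3],
               [0.4, -0.5, 0.6]])  # arbitrary increment

G = jax.grad(f)(X)                   # autodiff gradient, same shape as X
_, df_H = jax.jvp(f, (X,), (H,))     # differential of f at X applied to H
trace_form = jnp.trace(2 * X @ H.T)  # Tr(2 X H^T) from the derivation

print(jnp.allclose(G, 2 * X))
print(jnp.allclose(df_H, trace_form))
```

I would expect both checks to print True, up to float32 rounding.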
Second order derivative
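Now that we know how to get matrix derivatives, let's try to find the second-order derivative of the same function f(X). Here H_1 is the increment from the first differentiation and H_2 is the increment from the second one: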
d/dX[df(X)/dX] = d/dX[Tr(2XH_1^T)] = Tr(d/dX[2XH_1^T]) =
= Tr(2I H_2 H_1^T)
So d/dX[df(X)/dX] = 2I, where I stands for the identity matrix. But how does this help us find derivatives in PyTorch?
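Before moving on, the second-order formula can be checked numerically too. The sketch below is my own and rests on an assumption about how to read "2I": I take it as twice the identity map on matrices, so the full second derivative from jax.hessian is a 4-index tensor that sends H_2 to 2*H_2. The values of X, H1 and H2 are arbitrary.

```python
import jax
import jax.numpy as jnp

def f(X):
    return jnp.trace(X @ X.T)

X = jnp.array([[1.0, -2.0], [0.5, 3.0]])
H1 = jnp.array([[0.1, 0.2], [-0.3, 0.4]])
H2 = jnp.array([[-1.0, 0.5], [2.0, 0.25]])

# Full second derivative: hess[i, j, k, l] = d^2 f / (dX[i, j] dX[k, l]).
hess = jax.hessian(f)(X)

# Apply it to both increments and compare with Tr(2 H_2 H_1^T).
second_diff = jnp.einsum("ijkl,ij,kl->", hess, H1, H2)
trace_form = jnp.trace(2 * H2 @ H1.T)

# Applied to H2 alone it should act as "2I", i.e. return 2 * H2.
applied = jnp.einsum("ijkl,kl->ij", hess, H2)

print(jnp.allclose(second_diff, trace_form))
print(jnp.allclose(applied, 2 * H2))
```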
Trace is the trick
As we can see from the formulas, both the first- and second-order derivatives have a trace inside them, but when we take the first-order derivative we immediately get a matrix as a result. To get a higher-order derivative we just need to take the derivative of the trace of the first-order derivative (which requires that derivative to be a square matrix):
d/dY[df/dX] = d/dY[Tr(df/dX)]
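For the single-argument f(X) = ||X||_F^2 the trick can be sketched as follows. This is my own illustration with an arbitrary square X; the helper name trace_of_grad is hypothetical. For comparison it also contracts the full 4-index tensor from jax.hessian over its first pair of indices, which, as far as I can tell, is the quantity the trace trick computes.

```python
import jax
import jax.numpy as jnp

def f(X):
    return jnp.trace(X @ X.T)

def trace_of_grad(X):
    # Scalarize the first-order derivative; X must be square for the trace.
    return jnp.trace(jax.grad(f)(X))

X = jnp.array([[1.0, -2.0], [0.5, 3.0]])

second = jax.grad(trace_of_grad)(X)  # d/dX[Tr(df/dX)]

# Same quantity from the full second derivative: sum over i of hess[i, i, k, l].
contracted = jnp.einsum("iikl->kl", jax.hessian(f)(X))

print(second)  # 2 * identity for this f
print(jnp.allclose(second, contracted))
```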
I was using JAX's autograd when this trick came to mind, so for a function f(X, Y) the code looks like this:
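import jax.numpy as jnp
from jax import grad

def scalarized_dy(X, Y):
    # first-order derivative with respect to Y, then its trace
    dY = grad(f, argnums=1)(X, Y)
    return jnp.trace(dY)

dYX = grad(scalarized_dy, argnums=0)(X, Y)  # d/dX[Tr(df/dY)]
dYY = grad(scalarized_dy, argnums=1)(X, Y)  # d/dY[Tr(df/dY)]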
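To make the snippet above runnable end to end, here is a sketch with a concrete f. The function f(X, Y) = Tr(X Y^T) + Tr(Y Y^T) and the input values are my own hypothetical choices, picked because the result is easy to verify by hand: df/dY = X + 2Y, so its trace is Tr(X) + 2 Tr(Y).

```python
import jax.numpy as jnp
from jax import grad

def f(X, Y):
    # Hypothetical test function: Tr(X Y^T) + ||Y||_F^2
    return jnp.trace(X @ Y.T) + jnp.trace(Y @ Y.T)

def scalarized_dy(X, Y):
    dY = grad(f, argnums=1)(X, Y)  # df/dY = X + 2Y for this f
    return jnp.trace(dY)           # Tr(X) + 2 Tr(Y)

X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
Y = jnp.array([[0.5, -1.0], [2.0, 0.0]])

dYX = grad(scalarized_dy, argnums=0)(X, Y)  # derivative of Tr(X) + 2 Tr(Y) in X
dYY = grad(scalarized_dy, argnums=1)(X, Y)  # derivative of Tr(X) + 2 Tr(Y) in Y

print(dYX)  # identity matrix
print(dYY)  # 2 * identity matrix
```

Both X and Y have to be square here, otherwise jnp.trace of the first-order derivative is not the trace from the derivation.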
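In PyTorch I guess the equivalent is torch.autograd.grad with create_graph=True (let loss be a function with X and Y as arguments). Calling backward() twice and then reading X.grad and Y.grad would mix the first-order gradients into the second-order results, because .grad accumulates, so the gradients are requested explicitly here:
import torch

loss = f(X, Y)  # X and Y must have requires_grad=True
# first-order derivative, keeping the graph so it can be differentiated again
(dX,) = torch.autograd.grad(loss, X, create_graph=True)
tr = torch.trace(dX)
# allow_unused=True returns None for an input that Tr(df/dX) does not depend on
dXX, dXY = torch.autograd.grad(tr, (X, Y), allow_unused=True)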
Epilogue
I think the question itself is interesting in its own way. It also took me several months to figure things out, so I decided to share my current view of the problem. I will not mark my answer as correct yet, in the hope of getting some feedback or, perhaps, even better answers or ideas.