
Question:

Is there a working method to calculate the gradient of a (non-scalar) tensor function?

Example

Given n-by-n symmetric matrices X and Y and the matrix function Z(X, Y) = torch.mm(X.mm(X), Y), calculate d(dZ/dX)/dY.

Expected answer

d(dZ/dX)/dY = d(2*XY)/dY = 2*X

Attempts

Because torch's .backward() works only for scalar variables, I tried to calculate the derivative by applying torch.autograd.grad() to each element of the tensor Z, but this approach is not correct: it gives d(X^2)/dX = X + 2*D, where D is a diagonal matrix with the diagonal values of X. It seems odd to me that torch can build a computational graph but can't track a tensor through it as a variable to get a tensor derivative.

Edit

The question was not very clear, so here are more details.

My aim is to get the partial derivatives of a loss function that involves two matrices as variables. It looks like this:

loss = torch.linalg.norm(my_formula(X, Y), ord='fro')

And I need to find

  1. d^2(loss)/d(Y^2)
  2. d/dX[d(loss)/dY]

Torch is capable of calculating 1. by using .backward() twice, but it's problematic to find 2. because torch.autograd.grad() expects a scalar input, not a tensor.

  • First I think you need to define what exactly you mean by your derivative of a matrix with respect to another matrix. But no matter what (conventional) definition you choose, I don't think your expected answer holds for any of them. Can you provide a proof of your expected answer? – flawr Aug 02 '22 at 10:31
  • @flawr I've tried to provide more information in the edit. As for the derivative, I think of it as one variable (not a table of variables) and use matrix cookbook in calculations. Correct me, if my approach is not right. – Elfat Sabitov Aug 04 '22 at 11:08
  • It is still unclear. You can't just treat them as scalar variables if you're involving matrix multiplications, unless you're maybe using some quite non-standard notion of derivative. Can you maybe clarify your question by providing an [MCVE] *with* some concrete inputs and expected outputs? – flawr Aug 04 '22 at 14:30

1 Answer


TL;DR

For a function f which takes a matrix and returns a scalar:

  1. Find the first-order derivative; let's name it dX
  2. Take its trace: Tr(dX)
  3. To get the mixed partial derivative, just differentiate the trace from above: d/dX[df/dY] = d/dX[Tr(df/dY)]

Intro

At the moment of posting the question I was not really that good at the theory of matrix derivatives, but now I know much more, all thanks to this Yandex ML book (unfortunately, I didn't find an English equivalent). This is an attempt to give a full answer to my question.

Basic Theory

Forgive me, Lord, for the ugly plain-text representation of LaTeX

Let's say you have a function which takes a matrix X and returns its squared Frobenius norm: f(X) = ||X||_F^2
It is a well-known fact that: ||X||_F^2 = Tr(X X^T)
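This identity is easy to check numerically on a concrete matrix; a minimal pure-Python sketch (the helper names are mine, not from the book):

```python
def trace(A):
    # sum of the diagonal entries of a square matrix
    return sum(A[i][i] for i in range(len(A)))

def times_transpose(X):
    # compute X @ X^T explicitly: (X X^T)_ij = sum_k X_ik * X_jk
    n, m = len(X), len(X[0])
    return [[sum(X[i][k] * X[j][k] for k in range(m)) for j in range(n)]
            for i in range(n)]

X = [[1.0, 2.0], [3.0, 4.0]]
frob_sq = sum(x * x for row in X for x in row)  # ||X||_F^2 = 1 + 4 + 9 + 16 = 30
assert abs(trace(times_transpose(X)) - frob_sq) < 1e-12  # Tr(X X^T) = 30 as well
```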

Let's define the derivative as in the same book: the differential of f at X is the part of the increment f(X + H) - f(X) that is linear in H.
We are ready to find df(X)/dX:

df(X)/dX = dTr(X X^T)/dX =

(trace and differential commute)
= Tr(d/dX[X X^T]) = Tr(dX/dX X^T + X d[X^T]/dX) =

(then we substitute the increment H for dX, using the definition of derivative from above)
= Tr(HX^T + XH^T) = Tr(HX^T) + Tr(XH^T) =

(now the main trick is to get all matrices H on the right side and get something like
Tr(g(X) H) or Tr(g(X) H^T), where g(X) will be the derivative we are looking for;
here we use the identity Tr(HX^T) = Tr(XH^T))
= Tr(XH^T) + Tr(XH^T) = Tr(2*XH^T)

That means: df(X)/dX = 2X
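As a sanity check, the result df(X)/dX = 2X can be verified with central finite differences; a minimal pure-Python sketch, no autograd library assumed (helper names are mine):

```python
def frob_sq(X):
    # f(X) = ||X||_F^2 = Tr(X X^T): sum of squared entries
    return sum(x * x for row in X for x in row)

def numeric_grad(f, X, eps=1e-6):
    # entry-wise central differences: (f(X + eps*E_ij) - f(X - eps*E_ij)) / (2*eps)
    n, m = len(X), len(X[0])
    G = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            Xp = [row[:] for row in X]; Xp[i][j] += eps
            Xm = [row[:] for row in X]; Xm[i][j] -= eps
            G[i][j] = (f(Xp) - f(Xm)) / (2 * eps)
    return G

X = [[1.0, 2.0], [3.0, 4.0]]
G = numeric_grad(frob_sq, X)
# the numeric gradient should match 2X entry by entry
assert all(abs(G[i][j] - 2 * X[i][j]) < 1e-4 for i in range(2) for j in range(2))
```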

Second order derivative

Now that we know how to get matrix derivatives, let's find the second-order derivative of the same function f(X):

d/dX[df(X)/dX] = d/dX[Tr(2XH_1^T)] = Tr(d/dX[2XH_1^T]) = Tr(2I H_2 H_1^T)

We found that d/dX[df(X)/dX] = 2I, where I stands for the identity matrix. But how does this help us find derivatives in PyTorch?
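The claim that the second derivative is 2I can also be checked entry-wise: the mixed partial d^2 f / dX_ij dX_kl should be 2 when (i,j) = (k,l) and 0 otherwise. A small pure-Python finite-difference sketch (helper names are mine):

```python
def frob_sq(X):
    # f(X) = ||X||_F^2: sum of squared entries
    return sum(x * x for row in X for x in row)

def second_partial(f, X, ij, kl, eps=1e-3):
    # mixed central difference approximating d^2 f / dX_ij dX_kl
    def bump(X, idx, h):
        Y = [row[:] for row in X]
        Y[idx[0]][idx[1]] += h
        return Y
    fpp = f(bump(bump(X, ij, eps), kl, eps))
    fpm = f(bump(bump(X, ij, eps), kl, -eps))
    fmp = f(bump(bump(X, ij, -eps), kl, eps))
    fmm = f(bump(bump(X, ij, -eps), kl, -eps))
    return (fpp - fpm - fmp + fmm) / (4 * eps * eps)

X = [[1.0, 2.0], [3.0, 4.0]]
# "diagonal" entry of the Hessian (same index twice): 2
assert abs(second_partial(frob_sq, X, (0, 0), (0, 0)) - 2.0) < 1e-6
# "off-diagonal" entry (two different indices): 0
assert abs(second_partial(frob_sq, X, (0, 0), (1, 1))) < 1e-6
```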

Trace is the trick

As we can see from the formulas, both the first- and second-order derivatives have a trace inside them, but when we take the first-order derivative we instantly get a matrix as a result. To get a higher-order derivative, we just need to take the derivative of the trace of the first-order derivative:

d/dY[df/dX] = d/dY[Tr(df/dX)]

The thing is, I was using the JAX autograd library when this trick came to my mind, so for a scalar-valued function f(X, Y) the code looks like this:

from jax import grad
import jax.numpy as jnp

def scalarized_dy(X, Y):
    dY = grad(f, argnums=1)(X, Y)  # first-order derivative w.r.t. Y
    return jnp.trace(dY)           # scalarize it with the trace

dYX = grad(scalarized_dy, argnums=0)(X, Y)  # d/dX[df/dY]
dYY = grad(scalarized_dy, argnums=1)(X, Y)  # d/dY[df/dY]

In the case of PyTorch, I guess we will need to look after the tensors' gradients (let f be a scalar-valued function of X and Y, both created with requires_grad=True):

loss = f(X, Y)
loss.backward(create_graph=True)  # keep the graph so X.grad can be differentiated again
dX = torch.trace(X.grad)
# reset the accumulated first-order gradients; otherwise the second
# backward pass would add the second-order terms on top of them
X.grad = None
Y.grad = None
dX.backward()
dXX = X.grad  # d/dX[Tr(df/dX)]
dXY = Y.grad  # d/dY[Tr(df/dX)]

Epilogue

I think the question itself is interesting in its own right. It took me several months to figure things out, so I decided to share my current view of the problem. I will not mark my answer as correct yet, in the hope of getting some feedback or, perhaps, even better answers or ideas.