
Somewhere in my loss function I invert a 64×64 complex matrix. Although complex matrix inversion is supported for torch tensors, the gradient cannot be computed in the training loop; I get this error:

RuntimeError: inverse does not support automatic differentiation for outputs with complex type.

Does anyone have a workaround for this issue? A custom function instead of torch.inverse, maybe?

iacob
DeepFara

2 Answers


You can do the inverse yourself using the real-valued components of your complex matrix.

Some linear algebra first:

A complex matrix C can be written as a sum of two real matrices A and B, where j is the imaginary unit (j² = -1):

C = A + jB  

Finding the inverse of C amounts to finding two real-valued matrices x and y such that

(A + jB)(x + jy) = I + j0

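Expanding the product gives (A + jB)(x + jy) = (Ax - By) + j(Bx + Ay). Matching the real and imaginary parts reduces the problem to the real-valued system Ax - By = I and Bx + Ay = 0, or in block form:

$$\begin{bmatrix} A & -B \\ B & A \end{bmatrix} \begin{bmatrix} x \\ y \end{bmatrix} = \begin{bmatrix} I \\ 0 \end{bmatrix}$$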

Now that we know how to reduce a complex matrix inversion to a real-valued linear system, we can let PyTorch's linear solver do the inversion for us.

import torch

def complex_inverse(C):
    """Invert a square complex matrix C by solving an equivalent real 2n x 2n system."""
    n = C.shape[0]
    A = torch.real(C)
    B = torch.imag(C)
    # construct the left-hand side [[A, -B], [B, A]] of the system of equations
    # (newer PyTorch releases also offer torch.vstack / torch.hstack instead of cat)
    lhs = torch.cat([torch.cat([A, -B], dim=1), torch.cat([B, A], dim=1)], dim=0)
    # construct the right-hand side [I; 0] with A's dtype and device
    rhs = torch.cat([torch.eye(n, dtype=A.dtype, device=A.device), torch.zeros_like(A)], dim=0)

    # solve lhs @ [x; y] = rhs; torch.linalg.solve(A, B) replaces the removed torch.solve(B, A)
    raw = torch.linalg.solve(lhs, rhs)
    # write the solution x + jy as a single complex matrix
    iC = torch.complex(raw[:n, :], raw[n:, :])
    return iC

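You can verify the solution using NumPy:

import numpy as np

# C is a square complex torch tensor
iC = complex_inverse(C)

# detach() is required: .numpy() refuses tensors that require grad, even inside torch.no_grad();
# for complex64 inputs you may need a looser tolerance than np.isclose's default
print(np.isclose(iC.detach().cpu().numpy() @ C.detach().cpu().numpy(), np.eye(C.shape[0])).all())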

Note that block-matrix inversion tricks can reduce the computational cost of the solve operation.
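A minimal sketch of that block-matrix idea, assuming the real part A is invertible (this is an illustration, not code from the answer; the name complex_inverse_blockwise is hypothetical). Eliminating y from Ax - By = I and Bx + Ay = 0 gives y = -A^{-1}Bx and (A + BA^{-1}B)x = I, so the inverse can be computed with n×n real solves instead of one 2n×2n solve:

```python
import torch

def complex_inverse_blockwise(C):
    """Hypothetical helper: invert complex C = A + jB using only n x n real solves.

    Assumes the real part A is invertible.
    """
    A = torch.real(C)
    B = torch.imag(C)
    # A^{-1} B, computed with a solve rather than an explicit inverse
    AinvB = torch.linalg.solve(A, B)
    # Schur complement S = A + B A^{-1} B; the real part x satisfies S x = I
    S = A + B @ AinvB
    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    x = torch.linalg.solve(S, eye)
    # imaginary part y = -A^{-1} B x, which makes B x + A y = 0
    y = -AinvB @ x
    return torch.complex(x, y)

# usage: compare against PyTorch's own complex inverse
torch.manual_seed(0)
C = torch.randn(8, 8, dtype=torch.complex128)
iC = complex_inverse_blockwise(C)
print(torch.allclose(C @ iC, torch.eye(8, dtype=torch.complex128), atol=1e-6))
```

The trade-off: this only works when A is invertible and can lose accuracy when A is ill-conditioned, whereas the 2n×2n system above only needs C itself to be invertible.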

Shai

As of version 1.9, PyTorch supports complex autograd.
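With complex autograd available, the workaround should no longer be necessary: torch.linalg.inv can be called on a complex tensor and differentiated directly. A minimal sketch, assuming a recent PyTorch; the loss (squared Frobenius norm of the inverse) is a hypothetical stand-in for the asker's actual loss:

```python
import torch

torch.manual_seed(0)
# a random 64x64 complex matrix that we want gradients for
C = torch.randn(64, 64, dtype=torch.complex128, requires_grad=True)

# differentiate straight through the complex inverse
iC = torch.linalg.inv(C)

# a real-valued scalar loss built from the complex inverse (hypothetical example)
loss = iC.abs().pow(2).sum()
loss.backward()

# the gradient has the same shape and complex dtype as C
print(C.grad.shape, C.grad.dtype)
```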

iacob