
Is there a way to calculate the determinant of a complex matrix in PyTorch?

torch.det is not implemented for 'ComplexFloat'

iacob
DeepRazi

2 Answers


Unfortunately, it's not currently implemented. One option is to write your own version; another is simply to use np.linalg.det. Here is a short function I wrote that computes the determinant of a complex matrix using LU decomposition:

import torch

def complex_det(A):
    def complex_diag(A):
        # Extract the diagonal of a complex matrix by stacking its real and imaginary parts
        return torch.view_as_complex(torch.stack((A.real.diag(), A.imag.diag()), dim=1))
    # Perform LU decomposition of A (A = P @ L @ U):
    A_LU, pivots = A.lu()
    P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
    # The determinant of a product is the product of the determinants.
    # det(L) and det(U) are the products of their diagonals; det(P) is +-1
    # (the sign of the permutation, which could also be read off the pivots)
    det = torch.prod(complex_diag(A_L)) * torch.prod(complex_diag(A_U)) * torch.det(P.real)
    return det

# Test it:
A = torch.view_as_complex(torch.randn(3, 3, 2))
print(complex_det(A))
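The permutation matrix P in the LU approach always has determinant +1 or -1 (the sign of the permutation), so it does not need a full torch.det call: LAPACK-style pivots record, for each row i, the (1-based) row it was swapped with, and every pivot that differs from its own row index is one swap. A minimal sketch, assuming a recent PyTorch release that provides torch.linalg.lu_factor with complex support; the function name `complex_det_from_pivots` is hypothetical. It also works on batches and is cross-checked against np.linalg.det.

```python
import numpy as np
import torch

# Hypothetical helper: determinant via LU factorization, with det(P) taken from the pivots
def complex_det_from_pivots(A: torch.Tensor) -> torch.Tensor:
    # Packed LU factorization: A = P @ L @ U, where L is unit lower triangular,
    # so det(L) = 1 and det(A) = det(P) * prod(diag(U))
    LU, pivots = torch.linalg.lu_factor(A)
    # The diagonal of the packed LU tensor is the diagonal of U
    diag_prod = torch.prod(torch.diagonal(LU, dim1=-2, dim2=-1), dim=-1)
    # Pivots are 1-based; each pivot that differs from its row index is one row swap
    n = A.shape[-1]
    idx = torch.arange(1, n + 1, device=A.device, dtype=pivots.dtype)
    num_swaps = (pivots != idx).sum(dim=-1)
    sign = 1 - 2 * (num_swaps % 2)  # +1 for an even number of swaps, -1 for odd
    return sign.to(diag_prod.dtype) * diag_prod

# Usage: compare with NumPy on a random complex matrix
A = torch.randn(3, 3, dtype=torch.cfloat)
print(complex_det_from_pivots(A))
print(np.linalg.det(A.numpy()))
```

Reading the sign off the pivots avoids a second O(n^3) determinant call on P, which is what the original comment about det(P) was hinting at.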
Gil Pinsky

As of version 1.8, PyTorch has native support for numpy-style torch.linalg operations. In particular, torch.linalg.det has support for cfloat and cdouble complex number data-types:

torch.linalg.det(input)

Computes the determinant of a square matrix input, or of each square matrix in a batched input.

This function supports float, double, cfloat and cdouble dtypes.
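A short usage sketch, assuming PyTorch 1.8 or newer: it computes the determinant of a single complex matrix and of a batch of complex matrices with torch.linalg.det, then cross-checks the results against np.linalg.det.

```python
import numpy as np
import torch

# A single 3x3 complex matrix and a batch of five 3x3 complex matrices
A = torch.randn(3, 3, dtype=torch.cfloat)
batch = torch.randn(5, 3, 3, dtype=torch.cdouble)

det_A = torch.linalg.det(A)          # 0-dim tensor with dtype torch.cfloat
det_batch = torch.linalg.det(batch)  # shape (5,): one determinant per matrix

# Cross-check against NumPy
print(det_A, np.linalg.det(A.numpy()))
print(torch.allclose(det_batch, torch.from_numpy(np.linalg.det(batch.numpy()))))
```

Because the function broadcasts over leading batch dimensions, no Python loop is needed to handle many matrices at once.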

iacob