Tags: python, pytorch, complex-numbers, determinants

Determinant of a complex matrix in PyTorch


Is there a way to calculate the determinant of a complex matrix in PyTorch?

Calling torch.det on such a matrix fails because it is not implemented for 'ComplexFloat'.


Solution

  • Unfortunately, it is not implemented at the time of writing. One option is to implement your own version; another is to simply use np.linalg.det, which accepts complex matrices. Here is a short function I wrote that computes the determinant of a complex matrix using LU decomposition:

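    import torch

    def complex_det(A):
        def complex_diag(A):
            # Take the diagonals of the real and imaginary parts separately
            # and recombine them into a complex vector
            return torch.view_as_complex(torch.stack((A.real.diag(), A.imag.diag()), dim=1))
        # Perform LU decomposition of matrix A (A = P @ L @ U):
        A_LU, pivots = A.lu()
        P, A_L, A_U = torch.lu_unpack(A_LU, pivots)
        # The determinant of a product is the product of the determinants, and the
        # determinant of a triangular matrix is the product of its diagonal:
        det = torch.prod(complex_diag(A_L)) * torch.prod(complex_diag(A_U)) * torch.det(P.real)
        # det(P) is always +1 or -1, so it could also be computed more cheaply
        # from the parity of the row swaps recorded in `pivots`
        return det

    # Test it:
    A = torch.view_as_complex(torch.randn(3, 3, 2))
    complex_det(A)  # compare with np.linalg.det(A.numpy())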
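To show the arithmetic behind the LU approach without depending on PyTorch, here is a minimal pure-Python sketch, assuming the matrix is given as a plain list of lists of numbers (the name lu_det is hypothetical and not part of any library). It uses Gaussian elimination with partial pivoting, which implicitly builds the same P, L and U factors: det(A) = det(P) · det(L) · det(U), where L has a unit diagonal so det(L) = 1, det(U) is the product of U's diagonal, and det(P) = (-1)^s with s the number of row swaps. Tracking the swaps this way is the cheap alternative to calling torch.det on P that the code comment alludes to.

```python
from typing import List


def lu_det(a: List[List[complex]]) -> complex:
    """Determinant via Gaussian elimination with partial pivoting (LU-style).

    Hypothetical helper for illustration; works for real or complex entries.
    """
    n = len(a)
    # Work on a complex copy so the caller's matrix is not modified
    m = [[complex(x) for x in row] for row in a]
    det = complex(1.0)
    for k in range(n):
        # Partial pivoting: choose the row with the largest |entry| in column k
        p = max(range(k, n), key=lambda i: abs(m[i][k]))
        if m[p][k] == 0:
            return complex(0.0)  # the whole column is zero, so A is singular
        if p != k:
            m[k], m[p] = m[p], m[k]
            det = -det  # each row swap flips the sign: det(P) = (-1)**swaps
        # After elimination, m[k][k] is the k-th diagonal entry of U
        det *= m[k][k]
        # Eliminate entries below the pivot (the multipliers are L's entries)
        for i in range(k + 1, n):
            factor = m[i][k] / m[k][k]
            for j in range(k, n):
                m[i][j] -= factor * m[k][j]
    return det


# Usage: (1+1j)(4-2j) - 2*3 = 2j, so the result should be approximately 2j
A = [[1 + 1j, 2], [3, 4 - 2j]]
print(lu_det(A))
```

The result is subject to ordinary floating-point rounding, just like the PyTorch and NumPy versions, so compare against other implementations with a tolerance rather than exact equality.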