Tags: python, scipy, pytorch, linear-algebra, backpropagation

Solving Sylvester equations in PyTorch


I'm trying to solve a Sylvester matrix equation of the form

AX + XB = C

From what I've seen, these equations are usually solved with the Bartels-Stewart algorithm, which takes successive Schur decompositions. I'm aware scipy.linalg already has a solve_sylvester function, but I'm integrating the solution to the Sylvester equation into a neural network, so I need a way to calculate gradients to make A, B, and C learnable. Currently, I'm just solving a linear system with torch.linalg.solve using the Kronecker product and vectorization trick, but this has terrible runtime complexity. I haven't found any PyTorch support for Sylvester equations, let alone Schur decompositions, but before I try to implement Bartels-Stewart on the GPU, is there a simpler way to find the gradients?
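
For reference, what I'm doing now looks roughly like the sketch below (unbatched; sylvester_kron is just an illustrative name). It relies on the identity vec(AX + XB) = (A ⊗ I_m + I_n ⊗ Bᵀ) vec(X) for row-major vectorization, which is why it is differentiable but scales as O((nm)³):

    import torch

    def sylvester_kron(A, B, C):
        # Naive baseline: build the (n*m) x (n*m) system
        # (A kron I_m + I_n kron B^T) vec(X) = vec(C), with vec() stacking rows
        # (row-major, matching reshape). Differentiable, but O((n*m)^3) time.
        n, m = A.shape[-1], B.shape[-1]
        I_n = torch.eye(n, dtype=A.dtype, device=A.device)
        I_m = torch.eye(m, dtype=A.dtype, device=A.device)
        K = torch.kron(A, I_m) + torch.kron(I_n, B.mT)
        x = torch.linalg.solve(K, C.reshape(n * m))
        return x.reshape(n, m)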


Solution

  • Initially I wrote a solution that gives a complex X, based on the Bartels-Stewart algorithm, for the m = n case. I had some problems because the eigenvector matrix is not accurate enough. Also, for real inputs the real part gives the real solution, and the imaginary part must be a solution of AX - XB = 0.

    import torch
    
    def sylvester(A, B, C):
        # Solves AX - XB = C by diagonalising A and B:
        # A = U diag(R) U^-1 and B = V diag(S) V^-1.
        m = B.shape[-1]
        n = A.shape[-1]
        R, U = torch.linalg.eig(A)
        S, V = torch.linalg.eig(B)
        # Move the right-hand side into the eigenbases: F = U^-1 C V.
        F = torch.linalg.solve(U, (C + 0j) @ V)
        # Denominators R_i - S_j of the diagonalised equation.
        W = R[..., :, None] - S[..., None, :]
        Y = F / W
        # Map back: X = U Y V^-1.
        X = U[..., :n, :n] @ Y[..., :n, :m] @ torch.linalg.inv(V)[..., :m, :m]
        # For real inputs the real part is the real solution.
        return X.real if all(not torch.is_complex(x) for x in [A, B, C]) else X
    
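    Note that, as written, the function solves AX - XB = C (this is also what the check below uses). To get the AX + XB = C form from the question, just flip the sign of B:

    # AX + XB = C is the same as AX - X(-B) = C, so negate B.
    X = sylvester(A, -B, C)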

    As can be verified on the GPU with

    device='cuda'
    # Try different dimensions
    for batch_size, M, N in [(1, 4, 4), (20, 16, 16), (6, 13, 17), (11, 29, 23)]:
        print(batch_size, (M, N))
        A = torch.randn((batch_size, N, N), dtype=torch.float64, 
                        device=device, requires_grad=True)
        B = torch.randn((batch_size, M, M), dtype=torch.float64, 
                        device=device, requires_grad=True)
        X = torch.randn((batch_size, N, M), dtype=torch.float64, 
                        device=device, requires_grad=True)
        C = A @ X - X @ B
        X_ = sylvester(A, B, C)
        # Residual of the recovered solution.
        C_ = A @ X_ - X_ @ B
        print(torch.max(abs(C - C_)))
        # Backpropagate through the solver to confirm gradients reach A, B and X.
        X_.sum().backward()
    
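    Gradient correctness can also be spot-checked numerically with torch.autograd.gradcheck (a sketch: it uses small float64 inputs, and assumes the eigenvalues are distinct, which is required for torch.linalg.eig to be differentiable):

    from torch.autograd import gradcheck

    A = torch.randn(1, 4, 4, dtype=torch.float64, device=device, requires_grad=True)
    B = torch.randn(1, 3, 3, dtype=torch.float64, device=device, requires_grad=True)
    C = torch.randn(1, 4, 3, dtype=torch.float64, device=device, requires_grad=True)
    # Compares autograd's analytic gradients against finite differences.
    print(gradcheck(sylvester, (A, B, C)))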

    A faster variant, which replaces the explicit inverses above with conjugate transposes but is inaccurate in the current PyTorch version, is:

    def sylvester_of_the_future(A, B, C):
        def h(V):
            # Conjugate transpose, used in place of the explicit inverses above.
            return V.transpose(-1, -2).conj()
        m = B.shape[-1]
        n = A.shape[-1]
        R, U = torch.linalg.eig(A)
        S, V = torch.linalg.eig(B)
        F = h(U) @ (C + 0j) @ V
        W = R[..., :, None] - S[..., None, :]
        Y = F / W
        X = U[..., :n, :n] @ Y[..., :n, :m] @ h(V)[..., :m, :m]
        return X.real if all(not torch.is_complex(x) for x in [A, B, C]) else X
    

    I will leave it here; maybe in the future it will work properly.