Tags: python, pytorch, linear-regression

Is there a function in PyTorch for matrix left division?


MATLAB has the backslash (\) operator, and SciPy has lsqr. Does PyTorch have an equivalent operator that solves systems of linear equations?

Specifically, I need to solve the matrix equation A*X=B for A, and I need autograd to be able to backpropagate the error through the operation.


Solution

  • There is no \ (left-division) operator in Python. The closest you will get is SciPy's implementation: scipy.sparse.linalg.lsqr.

    You can either use

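    • torch.solve to solve linear systems of the form AX=B with a square, invertible A (deprecated in newer PyTorch releases in favor of torch.linalg.solve(A, B))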

    • torch.lstsq (replaced by torch.linalg.lstsq(A, B) in newer releases) to

      • solve the least-squares problem min ||AX-B||_2 (if A.size(0) >= A.size(1)), or

      • solve the minimum-norm problem min ||X||_2 subject to AX=B (if A.size(0) < A.size(1))
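A minimal sketch of the three cases above with the `torch.linalg` API, assuming a PyTorch version that ships `torch.linalg.solve`, `torch.linalg.lstsq` and `torch.linalg.pinv` (the older `torch.solve` and `torch.lstsq` take the right-hand side first, as `(B, A)`). The random matrices are placeholders, and the minimum-norm case uses the pseudoinverse `torch.linalg.pinv` as a swapped-in technique rather than `lstsq`:

```python
import torch

torch.manual_seed(0)
n, k = 4, 2

# Square system AX = B. Adding n * I only keeps this random A well-conditioned;
# in practice A is whatever invertible matrix you have.
A = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
A.requires_grad_(True)
B = torch.randn(n, k, dtype=torch.float64)

X = torch.linalg.solve(A, B)  # solves AX = B without forming inv(A)
print(torch.allclose(A.detach() @ X.detach(), B))  # True

# Autograd flows through the solve, so d(loss)/dA is populated.
loss = X.pow(2).sum()
loss.backward()
print(A.grad.shape)  # torch.Size([4, 4])

# Overdetermined system (more rows than columns): least-squares solution.
A_tall = torch.randn(6, n, dtype=torch.float64)
B_tall = torch.randn(6, k, dtype=torch.float64)
X_ls = torch.linalg.lstsq(A_tall, B_tall).solution  # shape (n, k)
# For a full-column-rank A_tall this agrees with the pseudoinverse solution.
print(torch.allclose(X_ls, torch.linalg.pinv(A_tall) @ B_tall))

# Underdetermined system (fewer rows than columns): minimum-norm solution
# via the pseudoinverse, which fits B exactly when A_wide has full row rank.
A_wide = torch.randn(2, n, dtype=torch.float64)
B_wide = torch.randn(2, k, dtype=torch.float64)
X_mn = torch.linalg.pinv(A_wide) @ B_wide
print(torch.allclose(A_wide @ X_mn, B_wide))
```

Because `torch.linalg.solve` is differentiable, the same pattern works when `A` or `B` is an intermediate result inside a larger model.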


    To solve XA=B, transpose both sides: XA=B is equivalent to A^T X^T = B^T, so you can solve for X^T and transpose the result:

    import torch

    def lsqrt(A, B):
        # XA = B  <=>  A^T X^T = B^T: solve for X^T, then transpose back
        XT = torch.linalg.solve(A.T, B.T)
        return XT.T
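Applied to the question as asked (A*X=B with A unknown), the transpose trick gives A directly. A minimal sketch, assuming X is square and invertible and using a hypothetical helper name `solve_left_factor`; for a non-square X the least-squares route would replace the solve. The final check confirms that gradients reach both inputs:

```python
import torch

def solve_left_factor(X, B):
    """Hypothetical helper: return A such that A @ X = B.

    Assumes X is square and invertible. A @ X = B is equivalent to
    X^T @ A^T = B^T, so solve for A^T and transpose the result.
    """
    return torch.linalg.solve(X.T, B.T).T

torch.manual_seed(0)
n, k = 3, 5
# Shifting by n * I only keeps the random X well-conditioned.
X = (torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)).requires_grad_()
B = torch.randn(k, n, dtype=torch.float64, requires_grad=True)

A = solve_left_factor(X, B)  # shape (k, n)
print(torch.allclose(A.detach() @ X.detach(), B.detach()))  # True

# Any scalar built from A can be backpropagated to both X and B.
A.sum().backward()
print(X.grad.shape, B.grad.shape)  # torch.Size([3, 3]) torch.Size([5, 3])
```

If your PyTorch version supports it, `torch.linalg.solve` also accepts a `left=False` keyword that solves the XA=B form without explicit transposes; check the documentation for your release before relying on it.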