MATLAB has the backslash "\" operator. SciPy has "lsqr." Does PyTorch have an equivalent operator that solves systems of linear equations?
Specifically, I need to solve the matrix equation A*X=B for A, and I need autograd to be able to backpropagate error through the operation.
Python has no \ operator for this. Outside PyTorch, the closest you will get is SciPy's scipy.sparse.linalg.lsqr, but it operates on NumPy arrays and does not take part in PyTorch's autograd.
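For comparison, here is a minimal sketch of calling scipy.sparse.linalg.lsqr on a small dense system. The matrix and vector are made-up illustration values. Note that lsqr takes a single right-hand-side vector at a time and works on NumPy arrays, so gradients cannot flow through it in PyTorch.

```python
import numpy as np
from scipy.sparse.linalg import lsqr

# Small overdetermined system (3 equations, 2 unknowns); values are illustrative.
# b was chosen as A @ [1, 2], so the exact solution is x = [1, 2].
A = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
b = np.array([1.0, 2.0, 3.0])

# lsqr accepts a dense array (wrapped as a linear operator internally) and a
# 1-D right-hand side; it returns a tuple whose first element is the solution.
result = lsqr(A, b, atol=1e-12, btol=1e-12)
x = result[0]

# Cross-check against NumPy's dense least-squares solver.
x_ref, *_ = np.linalg.lstsq(A, b, rcond=None)
print(x, x_ref)
```

For several right-hand sides you would call lsqr once per column of B.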
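Within PyTorch, you can use either
- torch.solve to solve linear equations of the form AX=B, or
- torch.lstsq to
  - solve the least-squares problem min ||AX-B||_2 (if A.size(0) >= A.size(1)), or
  - solve the least-norm problem min ||X||_2 such that AX=B (if A.size(0) < A.size(1)).
Both functions have since been deprecated and removed. In current releases, use torch.linalg.solve(A, B) (the argument order is reversed relative to torch.solve(B, A)) and torch.linalg.lstsq(A, B). torch.linalg.solve is differentiable, so autograd can backpropagate through it.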
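A minimal sketch of the three cases using the torch.linalg functions available in current PyTorch. The random inputs are illustrative assumptions. For the minimum-norm case it uses torch.linalg.pinv, which is my choice as one differentiable route rather than something the answer prescribes. The square case also runs backward() to show that gradients reach both A and B.

```python
import torch

torch.manual_seed(0)  # reproducible illustration values

# Square system A X = B; adding 3*I makes A very likely well conditioned.
A = (torch.randn(3, 3, dtype=torch.float64)
     + 3 * torch.eye(3, dtype=torch.float64)).requires_grad_()
B = torch.randn(3, 2, dtype=torch.float64, requires_grad=True)

X = torch.linalg.solve(A, B)   # argument order is (A, B), unlike the old torch.solve(B, A)
loss = X.pow(2).sum()          # any scalar function of X works here
loss.backward()                # autograd fills A.grad and B.grad

# Tall A (more rows than columns): least-squares solution of min ||A X - B||_2.
A_tall = torch.randn(5, 3, dtype=torch.float64)
B_tall = torch.randn(5, 2, dtype=torch.float64)
X_ls = torch.linalg.lstsq(A_tall, B_tall).solution

# Wide A (fewer rows than columns): minimum-norm solution of A X = B.
# For a full-rank wide A, pinv(A) @ B equals A^T (A A^T)^{-1} B.
A_wide = torch.randn(2, 4, dtype=torch.float64)
B_wide = torch.randn(2, 1, dtype=torch.float64)
X_mn = torch.linalg.pinv(A_wide) @ B_wide

print(X.shape, X_ls.shape, X_mn.shape)
```

The least-squares result satisfies the normal equations A^T (A X - B) = 0, which is a convenient sanity check when switching between solvers.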
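To solve XA=B for X (the form in your question, where the unknown multiplies from the left), transpose both sides to get A^T X^T = B^T and solve that system instead: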
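import torch

def lsqrt(A, B):
    # Solve X @ A = B via A.T @ X.T = B.T (torch.linalg.solve replaces the removed torch.solve(B, A)).
    XT = torch.linalg.solve(A.T, B.T)
    return XT.T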
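A self-contained sketch of the transpose trick, written against torch.linalg.solve on the assumption that you are on a recent PyTorch where torch.solve no longer exists. It checks that X @ A reproduces B and that autograd propagates a gradient back into A. The random, diagonally shifted A is an illustrative assumption.

```python
import torch

def lsqrt(A, B):
    # Solve X @ A = B for X by transposing both sides: A.T @ X.T = B.T.
    XT = torch.linalg.solve(A.T, B.T)
    return XT.T

torch.manual_seed(0)
# Shifting by 4*I keeps the illustrative A comfortably invertible.
A = (torch.randn(4, 4, dtype=torch.float64)
     + 4 * torch.eye(4, dtype=torch.float64)).requires_grad_()
B = torch.randn(2, 4, dtype=torch.float64)

X = lsqrt(A, B)        # shape (2, 4)
X.sum().backward()     # gradient flows back into A through the solve
print(torch.allclose(X @ A, B))
```

With the question's naming (A*X=B with A unknown), the same helper is called as lsqrt(X, B), with the known matrix first.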