I have this line of code that involves a matrix inversion:
X = A @ B @ np.linalg.pinv(S)
A is an n by n matrix, B is an n by m matrix, and S is an m by m matrix. m is smaller than n, but usually not by orders of magnitude; typically m is about half of n. S is a symmetric positive definite matrix.
How do I make this line of code run faster in Python?
I can do
X = np.linalg.solve(S.T, (A@B).T).T
But I am also curious whether I can take advantage of the fact that S is symmetric.
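A quick way to check that the solve-based rewrite matches the pseudo-inverse version is a sketch like the one below. The sizes n = 200, m = 100 and the random construction of S are assumptions chosen only for illustration. It also shows that, because S is symmetric, passing S instead of S.T gives the same result.

```python
import numpy as np

# Hypothetical sizes and random data, for illustration only (m is about n/2).
rng = np.random.default_rng(0)
n, m = 200, 100
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, m))
M = rng.standard_normal((m, m))
S = M @ M.T + m * np.eye(m)  # symmetric positive definite by construction
S = 0.5 * (S + S.T)          # enforce exact symmetry against round-off

# Original approach: explicit pseudo-inverse (numpy computes it via an SVD).
X_pinv = A @ B @ np.linalg.pinv(S)

# Solve X S = C through the transposed system S^T X^T = C^T.
C = A @ B
X_solve = np.linalg.solve(S.T, C.T).T

# S is symmetric, so S.T equals S and the transpose on S can be dropped.
X_solve_sym = np.linalg.solve(S, C.T).T

print(np.allclose(X_pinv, X_solve), np.allclose(X_solve, X_solve_sym))
```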
So your problem is XS = AB = C. As you've stated, this can be rewritten as S'X' = B'A' = C'. C' is of size m x n, so this is a batched problem (n right-hand sides that all share the same matrix S'), and it can be solved in a single call with scipy.linalg.solve. In this case, I recommend the scipy alternative (rather than numpy) because you have stated that S is symmetric, so you can pass the assume_a="sym" argument and scipy will select a solver that takes advantage of the matrix structure.
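To see the structure-aware options side by side, here is a sketch on hypothetical random data (the sizes and the construction of S are illustrative assumptions). Besides "sym", scipy.linalg.solve also accepts assume_a="pos" for positive definite matrices, which uses the Cholesky-based LAPACK routine ?posv; since S is stated to be positive definite as well as symmetric, that option may be worth benchmarking too. All three choices should agree up to round-off.

```python
import numpy as np
import scipy.linalg

# Hypothetical data for illustration only.
rng = np.random.default_rng(1)
n, m = 200, 100
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, m))
M = rng.standard_normal((m, m))
S = M @ M.T + m * np.eye(m)  # symmetric positive definite by construction
S = 0.5 * (S + S.T)          # enforce exact symmetry against round-off

C_T = (A @ B).T  # m x n right-hand side: n systems solved in one call

# "gen": general LU solver (?gesv); "sym": symmetric indefinite
# factorization (?sysv); "pos": positive definite, Cholesky (?posv).
# S is symmetric, so S is passed directly instead of S.T.
results = {
    kind: scipy.linalg.solve(S, C_T, assume_a=kind).T
    for kind in ("gen", "sym", "pos")
}

for kind, X in results.items():
    print(kind, X.shape, np.allclose(X, results["gen"]))
```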
So, your code will look like this:
import scipy.linalg
X = scipy.linalg.solve(S.T, (A@B).T, assume_a="sym").T
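If the same S is used with several different right-hand sides, one possible variant (a swapped-in technique, not part of the answer above) is to compute the Cholesky factorization of S once with scipy.linalg.cho_factor and reuse it with scipy.linalg.cho_solve. The data below is hypothetical, and whether this beats a single solve call at your sizes is something to measure rather than assume.

```python
import numpy as np
import scipy.linalg

# Hypothetical data for illustration only.
rng = np.random.default_rng(2)
n, m = 200, 100
A = rng.standard_normal((n, n))
B = rng.standard_normal((n, m))
M = rng.standard_normal((m, m))
S = M @ M.T + m * np.eye(m)  # symmetric positive definite by construction

# Cholesky factorization of S (upper-triangular factor by default).
# Raises LinAlgError if S is not positive definite.
c_and_lower = scipy.linalg.cho_factor(S)

# Solve S X^T = (A B)^T with the stored factor, then transpose back.
X = scipy.linalg.cho_solve(c_and_lower, (A @ B).T).T

# Compare against the original pseudo-inverse formulation.
print(np.allclose(X, A @ B @ np.linalg.pinv(S)))
```

The factor c_and_lower can be kept and passed to cho_solve again for each new right-hand side, so the O(m^3) factorization is paid only once.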