python numpy matrix linear-algebra numerical-methods

Optimize this Python code that involves matrix inversion


I have this line of code that involves a matrix inversion:

X = A @ B @ np.linalg.pinv(S)

A is an n by n matrix, B is an n by m matrix, and S is an m by m matrix. m is smaller than n, but usually not by orders of magnitude; typically m is about half of n. S is a symmetric positive definite matrix.

How do I make this line of code run faster in Python?

I can do

 X = np.linalg.solve(S.T, (A@B).T).T 

But I am also curious whether I can take advantage of the fact that S is symmetric.


Solution

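  • Your problem is XS = AB = C. As you've stated, transposing it gives S'X' = B'A' = C'. Here C is n x m, so C' is m x n, and the system has n right-hand sides that scipy.linalg.solve can handle in a single call. I recommend the scipy function over the numpy one because you have stated that S is symmetric: you can pass assume_a="sym" so that scipy selects a solver that takes advantage of the matrix structure. Since S is also positive definite, assume_a="pos" is worth trying as well; it selects a Cholesky-based solver, which is usually faster still. Note too that S.T equals S for a symmetric matrix, so the transpose on S is harmless but unnecessary.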

    So, your code will look like this:

    import scipy.linalg

    X = scipy.linalg.solve(S.T, (A@B).T, assume_a="sym").T
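To check that these variants agree with the original pinv expression, here is a minimal sketch. The sizes, the random seed, and the way S is constructed are assumptions made only for illustration; the cho_factor/cho_solve route is an alternative not mentioned above, which may help if the same S is reused with several right-hand sides.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve, solve

# Illustrative sizes only (assumption): m is about half of n.
rng = np.random.default_rng(0)
n, m = 200, 100

A = rng.standard_normal((n, n))
B = rng.standard_normal((n, m))

# Build a symmetric positive definite S for testing (assumption):
# M @ M.T is positive semidefinite, adding m * I makes it well conditioned.
M = rng.standard_normal((m, m))
S = M @ M.T + m * np.eye(m)

C = A @ B  # shape (n, m)

# Original formulation: explicit pseudo-inverse.
X_pinv = C @ np.linalg.pinv(S)

# Solve S X' = C' instead; S.T is dropped because S is symmetric.
X_sym = solve(S, C.T, assume_a="sym").T
X_pos = solve(S, C.T, assume_a="pos").T

# Explicit Cholesky factorization, reusable for further right-hand sides.
factor = cho_factor(S)
X_chol = cho_solve(factor, C.T).T
```

All four results should match to floating-point tolerance (for example via np.allclose). Which variant is fastest depends on the sizes and on the underlying BLAS/LAPACK build, so it is worth timing them on your own matrices with timeit.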