I have a loop in which I compute several pseudoinverses of rather large, non-sparse matrices (e.g. 20000x800).
As my code spends most of its time in pinv, I am trying to find a way to speed up the computation. I'm already using multiprocessing (joblib/loky) to run several processes, but that of course also increases overhead. Using jit did not help much.
Is there a faster way or a better implementation to compute the pseudoinverse, using any function? Precision isn't key.
My current benchmark:

import time

import numba
import numpy as np
from numpy.linalg import pinv as np_pinv
from scipy.linalg import pinv as scipy_pinv
# Note: pinv2 was removed in SciPy 1.9; drop this import on newer versions.
from scipy.linalg import pinv2 as scipy_pinv2


@numba.njit
def np_jit_pinv(A):
    return np_pinv(A)


matrix = np.random.rand(20000, 800)

for pinv in [np_pinv, scipy_pinv, scipy_pinv2, np_jit_pinv]:
    start = time.time()
    pinv(matrix)
    print(f'{pinv.__module__ +"."+pinv.__name__} took {time.time()-start:.3f}')

Output:
numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
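One caveat with single-shot timings like these: the first call to a numba-jitted function includes its compilation time, so the np_jit_pinv figure may overstate its steady-state cost. A minimal sketch of a fairer harness, assuming a warm-up call and best-of-N repeats are acceptable (best_of and repeats are illustrative names, not part of the original benchmark):

```python
import time

import numpy as np


def best_of(func, A, repeats=3):
    """Return the fastest of `repeats` timed calls to func(A), in seconds."""
    func(A)  # warm-up call: absorbs one-off costs such as JIT compilation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(A)
        timings.append(time.perf_counter() - start)
    return min(timings)


if __name__ == "__main__":
    # Smaller than the 20000x800 case so the sketch runs quickly.
    matrix = np.random.rand(2000, 80)
    print(f"numpy.linalg.pinv: {best_of(np.linalg.pinv, matrix):.4f} s")
```

The same helper can be called with any of the pinv variants from the benchmark in place of np.linalg.pinv.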
EDIT: JAX seems to be 30% faster! Impressive! Thanks for letting me know, @yuri-brigance. On Windows it works well under WSL.
numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
jax._src.numpy.linalg.pinv took 0.995
Try with JAX:
import jax.numpy as jnp
jnp.linalg.pinv(A)
This seems to be slightly faster than regular numpy.linalg.pinv. On my machine, your benchmark looks like this:
jax._src.numpy.linalg.pinv took 3.127
numpy.linalg.pinv took 4.284
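Since precision isn't key, another option that might be worth timing is a different technique from the SVD-based pinv calls above: for a tall matrix with full column rank, the pseudoinverse equals (A^T A)^{-1} A^T, which only requires solving a small n x n system. This is a sketch under that full-rank assumption, not a drop-in replacement; it squares the condition number of A, so it loses accuracy on ill-conditioned inputs and is not valid for rank-deficient ones. The name pinv_normal_equations is made up for illustration.

```python
import numpy as np


def pinv_normal_equations(A):
    """Pseudoinverse of a tall, full-column-rank matrix via the normal equations.

    Assumes A has shape (m, n) with m >= n and rank n; otherwise the Gram
    matrix is singular and the result is meaningless or solve() raises.
    """
    gram = A.T @ A                       # (n, n) Gram matrix, small when m >> n
    return np.linalg.solve(gram, A.T)    # solves gram @ X = A.T, so X = pinv(A)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.random((2000, 80))
    # Compare against the SVD-based reference on a well-conditioned input.
    print(np.allclose(pinv_normal_equations(A), np.linalg.pinv(A)))
```

Whether this is actually faster depends on the BLAS/LAPACK build and the matrix shape, so it would need to be measured with the same benchmark.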