I'm writing code that projects a large number of particles (10k in total) onto an MLS (Moving Least Squares) surface. This requires calling lstsq 10k times, once per particle, on a matrix of shape (N, 6), where N is between 200 and 1k depending on the particle's location.

Currently I use `torch.linalg.lstsq` for the computation, which takes approximately half a second per call. I'm looking for a more efficient approach; any code examples or recommended libraries would be extremely helpful.
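Because N varies from 200 to 1k, the systems cannot be stacked directly into one array for a batched call (`torch.linalg.lstsq` accepts inputs of shape `(*, m, n)`, and the einsum versions below also expect a single 3-D array). A minimal NumPy sketch, assuming each A_i has full column rank, pads every system with zero rows: a zero row in both A and b contributes (0·x − 0)² = 0 to the residual and nothing to AᵀA or Aᵀb, so each padded system keeps its original least-squares solution. The helper name `pad_batch` is hypothetical.

```python
import numpy as np

def pad_batch(As, bs):
    """Stack ragged systems (A_i of shape (N_i, k), b_i of shape (N_i,))
    into zero-padded arrays of shape (M, N_max, k) and (M, N_max)."""
    n_max = max(a.shape[0] for a in As)
    k = As[0].shape[1]
    A = np.zeros((len(As), n_max, k))
    b = np.zeros((len(As), n_max))
    for i, (a, r) in enumerate(zip(As, bs)):
        A[i, :a.shape[0]] = a  # rows past N_i stay zero ...
        b[i, :r.shape[0]] = r  # ... so they add nothing to ||A x - b||^2
    return A, b

# usage: three systems with different row counts
rng = np.random.default_rng(0)
As = [rng.random((n, 6)) for n in (200, 450, 1000)]
bs = [rng.random(n) for n in (200, 450, 1000)]
A, b = pad_batch(As, bs)
print(A.shape, b.shape)  # (3, 1000, 6) (3, 1000)
```

The cost of padding is wasted work on the zero rows, so grouping particles with similar N into separate batches may be worth trying if the spread of N is large.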
I've tried several methods:

`numpy.linalg.lstsq` with a for-loop: about 1.3 s.

`scipy.linalg.lstsq` with the `gelsy` LAPACK driver and a for-loop: also about 1.3 s.
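For comparison, the two for-loop approaches amount to something like the sketch below (the function name `lstsq_loop` is hypothetical). Each of the 10k iterations is a separate Python-level call on a tiny (N, 6) problem, so per-call overhead likely accounts for much of the time; that is the motivation for the batched versions.

```python
import numpy as np

def lstsq_loop(A, b):
    """Solve each system A[i] x = b[i] in the least-squares sense, one call per system."""
    out = np.empty((A.shape[0], A.shape[2]))
    for i in range(A.shape[0]):
        # rcond=None uses the machine-precision cutoff for small singular values
        out[i] = np.linalg.lstsq(A[i], b[i], rcond=None)[0]
        # SciPy variant: scipy.linalg.lstsq(A[i], b[i], lapack_driver='gelsy')[0]
    return out

# usage: consistent systems built from a known solution, so lstsq recovers it
rng = np.random.default_rng(0)
A = rng.random((100, 300, 6))
x_true = rng.random((100, 6))
b = np.einsum('ijk,ik->ij', A, x_true)  # b[i] = A[i] @ x_true[i]
print(np.allclose(lstsq_loop(A, b), x_true))  # True
```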
SVD method: about 1.5 s. It looks like:
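```python
import numpy as np

def lstsq_svd(A, b):
    # A: (M, N, 6) stack of matrices, b: (M, N) right-hand sides
    u, s, vh = np.linalg.svd(A, full_matrices=False)  # vh holds V^T
    uTb = np.einsum('ijk,ij->ik', u, b)        # U^T b
    return np.einsum('ijk,ij->ik', vh, uTb / s)  # V (S^-1 U^T b)
```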
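For reference, the SVD version relies on the identity below, valid when each A_i (N × 6) has full column rank so that all six singular values are nonzero. The thin SVD gives U_i of shape N × 6 and Σ_i, V_i of shape 6 × 6; the two einsum calls map onto the two sums.

```latex
A_i = U_i \Sigma_i V_i^\top, \qquad
x_i = \arg\min_x \lVert A_i x - b_i \rVert_2 = V_i \Sigma_i^{-1} U_i^\top b_i

(U_i^\top b_i)_k = \sum_j (U_i)_{jk}\,(b_i)_j
  \quad \text{(first einsum, 'ijk,ij->ik')}

(x_i)_k = \sum_j (V_i^\top)_{jk}\, y_j, \qquad y = \Sigma_i^{-1} U_i^\top b_i
  \quad \text{(second einsum; np.linalg.svd returns } V_i^\top \text{)}
```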
Solving the normal equations (A.T @ A) x = A.T @ b with `np.linalg.solve`: about 1.5 s. It looks like:
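```python
import numpy as np

def lstsq_normal(A, b):
    ATA = np.einsum('ijk,ijl->ikl', A, A)  # A^T A per system, (M, 6, 6)
    ATb = np.einsum('ijk,ij->ik', A, b)    # A^T b per system, (M, 6)
    # trailing axis: NumPy >= 2.0 treats a 2-D b as a matrix, not a stack of vectors
    return np.linalg.solve(ATA, ATb[..., None])[..., 0]
```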
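A batched alternative that was not timed here is reduced QR. Forming AᵀA squares the condition number of A (cond(AᵀA) = cond(A)²), which can cost accuracy when the MLS basis columns differ greatly in scale; QR works on A directly and, like the einsum versions, handles the whole stack in one call. A minimal sketch, assuming NumPy ≥ 1.22 (where `np.linalg.qr` accepts stacked matrices) and full column rank; the name `lstsq_qr` is hypothetical and no speed claim is made for it.

```python
import numpy as np

def lstsq_qr(A, b):
    """Batched least squares via reduced QR: A = Q R, then solve R x = Q^T b.
    A: (M, N, k), b: (M, N). Assumes every A[i] has full column rank."""
    Q, R = np.linalg.qr(A)               # Q: (M, N, k), R: (M, k, k)
    QTb = np.einsum('ijk,ij->ik', Q, b)  # Q^T b per system, shape (M, k)
    # R is upper triangular; np.linalg.solve ignores that, but k = 6 is tiny
    return np.linalg.solve(R, QTb[..., None])[..., 0]

# usage: compare against per-system np.linalg.lstsq
rng = np.random.default_rng(1)
A = rng.random((50, 400, 6))
b = rng.random((50, 400))
ref = np.stack([np.linalg.lstsq(A[i], b[i], rcond=None)[0] for i in range(50)])
print(np.allclose(lstsq_qr(A, b), ref))  # True
```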
Multiprocessing. Because I use the Taichi library in another part of my code, starting worker processes initializes multiple Taichi backends, which freezes my computer. My attempt looks like:
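```python
from multiprocessing import Pool
import numpy as np

def lstsq_pool(A, b):
    # calc_lstsq is my per-system solver (not shown); it must be a top-level function
    with Pool() as pool:
        c = pool.starmap(calc_lstsq, [(A[i], b[i]) for i in range(b.shape[0])])
    return np.asarray(c)
```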
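Threads, unlike the `Pool` workers, stay inside one process, so nothing is re-imported and Taichi is initialized only once. A minimal sketch, assuming the freeze comes from the extra processes: split the batch into a few large chunks and run the batched normal-equation solve on each chunk in a `ThreadPoolExecutor`, instead of submitting 10k tiny tasks. Whether this actually scales depends on how much of NumPy's work releases the GIL and on whether the BLAS/LAPACK library is already multithreaded. The names `normal_chunk` and `lstsq_threaded` are hypothetical.

```python
from concurrent.futures import ThreadPoolExecutor
import os
import numpy as np

def normal_chunk(A, b):
    # batched normal equations for one chunk: (A^T A) x = A^T b
    ATA = np.einsum('ijk,ijl->ikl', A, A)
    ATb = np.einsum('ijk,ij->ik', A, b)
    return np.linalg.solve(ATA, ATb[..., None])[..., 0]

def lstsq_threaded(A, b, n_chunks=None):
    # at most one chunk per system, so no chunk is empty
    n_chunks = min(n_chunks or os.cpu_count() or 1, A.shape[0])
    splits = np.array_split(np.arange(A.shape[0]), n_chunks)
    with ThreadPoolExecutor(max_workers=n_chunks) as ex:
        # contiguous slices are views, so chunks are not copied
        parts = ex.map(lambda idx: normal_chunk(A[idx[0]:idx[-1] + 1],
                                                b[idx[0]:idx[-1] + 1]), splits)
        return np.concatenate(list(parts))

# usage: chunked result should match solving the whole batch at once
rng = np.random.default_rng(2)
A = rng.random((1000, 300, 6))
b = rng.random((1000, 300))
print(np.allclose(lstsq_threaded(A, b, n_chunks=4), normal_chunk(A, b)))  # True
```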
I recommend JAX or Numba. Using your normal-equation example, I get a speed-up over plain NumPy of about 11x with JAX and about 19x with Numba (timings below):
```python
import numpy as np
from jax import config
config.update("jax_enable_x64", True)  # use float64 so results match NumPy
import jax.numpy as jnp
import jax
import numba as nb

rng = np.random.default_rng()
A = rng.random((10000, 1000, 6))  # 10k systems, N = 1000 rows, 6 unknowns
b = rng.random((10000, 1000))

def normal_np(A, b):
    ATA = np.einsum('ijk,ijl->ikl', A, A)  # A^T A per system, (M, 6, 6)
    ATb = np.einsum('ijk,ij->ik', A, b)    # A^T b per system, (M, 6)
    # trailing axis: NumPy >= 2.0 treats a 2-D b as a matrix, not a stack of vectors
    return np.linalg.solve(ATA, ATb[..., None])[..., 0]

@jax.jit
def normal_jax(A, b):
    ATA = jnp.einsum('ijk,ijl->ikl', A, A)
    ATb = jnp.einsum('ijk,ij->ik', A, b)
    return jnp.linalg.solve(ATA, ATb[..., None])[..., 0]

@nb.njit(parallel=True)
def normal_nb(A, b):
    assert A.shape[:-1] == b.shape
    assert A.ndim == 3
    output = np.zeros((A.shape[0], A.shape[2]))
    for i in nb.prange(A.shape[0]):  # systems are solved in parallel
        ai = A[i]
        aiT = ai.T
        bi = b[i]
        output[i] = np.linalg.solve(aiT @ ai, aiT @ bi)
    return output
```
Timings (the `%timeit` lines need IPython/Jupyter):

```python
assert np.allclose(normal_np(A, b), normal_jax(A, b))
assert np.allclose(normal_np(A, b), normal_nb(A, b))
%timeit normal_np(A, b)
%timeit normal_jax(A, b).block_until_ready()
%timeit normal_nb(A, b)
```

Output:

```
882 ms ± 6.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
81.5 ms ± 294 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.7 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```