Tags: python, numpy, pytorch, linear-regression, linear-algebra

How to efficiently calculate lstsq 10k times in Python?


I'm currently writing code to project a large number of particles (10k in total) onto an MLS (Moving Least Squares) surface. This requires calling the lstsq function 10k times, where each matrix has shape (N, 6) and N is between 200 and 1,000 depending on the location of the particle.
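
Because N differs from particle to particle, the batched approaches below (which operate on 3-D arrays) need every system to have the same number of rows. One option, which is my assumption rather than something from the post, is to zero-pad each system to the largest N: a zero row adds (0·x − 0)² = 0 to the squared residual, so each padded system keeps the same least-squares solution as the original. A minimal sketch with random stand-in data (`pad_systems` is a hypothetical helper name):

```python
import numpy as np


def pad_systems(As, bs):
    """Zero-pad ragged (N_i, K) systems into (M, N_max, K) and (M, N_max) arrays."""
    n_max = max(a.shape[0] for a in As)
    k = As[0].shape[1]
    A_pad = np.zeros((len(As), n_max, k))
    b_pad = np.zeros((len(As), n_max))
    for i, (a, b) in enumerate(zip(As, bs)):
        # Copy the real rows; the remaining rows stay zero and don't change the solution
        A_pad[i, : a.shape[0]] = a
        b_pad[i, : b.shape[0]] = b
    return A_pad, b_pad


# Random stand-in data: 5 systems with N between 200 and 1000 rows
rng = np.random.default_rng(0)
As = [rng.random((n, 6)) for n in rng.integers(200, 1001, size=5)]
bs = [rng.random(a.shape[0]) for a in As]
A_pad, b_pad = pad_systems(As, bs)

# All padded systems can now go through one batched solve (normal equations here)
ATA = np.einsum('ijk,ijl->ikl', A_pad, A_pad)
ATb = np.einsum('ijk,ij->ik', A_pad, b_pad)
x = np.linalg.solve(ATA, ATb[..., None])[..., 0]
```

Padding everything to the largest N wastes some work on the zero rows; grouping particles into buckets of similar N would likely reduce that overhead.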

Currently, I am using torch.linalg.lstsq for the computation, which takes approximately half a second per call. I'm therefore looking for a more efficient approach. Any code examples or recommended libraries would be extremely helpful.

I've tried some methods:

  1. numpy.linalg.lstsq with a for-loop. This takes about 1.3 s.

  2. scipy.linalg.lstsq with the gelsy LAPACK driver and a for-loop. This also takes about 1.3 s.

  3. SVD method. This takes about 1.5 s and looks like:

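    import numpy as np

    def lstsq_svd(A, b):
        # A: (M, N, 6), b: (M, N); np.linalg.svd returns V transposed (vh)
        u, s, vh = np.linalg.svd(A, full_matrices=False)
        uTb = np.einsum('ijk,ij->ik', u, b)          # U^T b, shape (M, 6)
        return np.einsum('ijk,ij->ik', vh, uTb / s)  # V (U^T b / s), shape (M, 6)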
    
  4. Use np.linalg.solve on the normal equations A.T @ A x = A.T @ b. This takes about 1.5 s and looks like:

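    import numpy as np

    def lstsq_normal(A, b):
        ATA = np.einsum('ijk,ijl->ikl', A, A)  # A^T A, shape (M, 6, 6)
        ATb = np.einsum('ijk,ij->ik', A, b)    # A^T b, shape (M, 6)
        # NumPy >= 2.0 treats a 2-D b as a stack of matrices, so add and drop a trailing axis
        return np.linalg.solve(ATA, ATb[..., None])[..., 0]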
    
  5. Multiprocessing. Because I use the Taichi library elsewhere in my code, each worker process initializes its own Taichi backend, which causes my computer to freeze.

    from multiprocessing import Pool
    import numpy as np

    def lstsq_pool(A, b):
        # calc_lstsq(A_i, b_i) solves a single (N, 6) system; it is defined elsewhere
        with Pool() as pool:
            c = pool.starmap(calc_lstsq, zip(A, b))
        return np.asarray(c)
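
Since the freeze comes from each worker process setting up Taichi, one thing that might be worth trying is threads instead of processes: threads run inside the existing process, so nothing is re-imported or re-initialized per worker. NumPy generally releases the GIL while its LAPACK routines run, so the per-system solves can overlap, although BLAS's own threading may compete for the same cores. A sketch, assuming `calc_lstsq` wraps `np.linalg.lstsq` (the question does not show its body, so the version here is hypothetical, as is the name `lstsq_threaded`); it has not been timed here:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def calc_lstsq(A_i, b_i):
    # Hypothetical stand-in for the question's per-system solver
    return np.linalg.lstsq(A_i, b_i, rcond=None)[0]


def lstsq_threaded(A, b, max_workers=None):
    # Threads share this process, so modules such as Taichi are initialized only once
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        # Iterating A and b yields one (N, 6) matrix and one (N,) vector per particle
        results = ex.map(calc_lstsq, A, b)
        return np.asarray(list(results))
```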
    

Solution

  • I recommend JAX or Numba. Using your normal-equations example, I get a speedup of ~10x with JAX and ~17x with Numba:

    import numpy as np
    from jax import config
    config.update("jax_enable_x64", True)
    import jax.numpy as jnp
    import jax
    import numba as nb
    
    rng = np.random.default_rng()
    A = rng.random((10000, 1000, 6))
    b = rng.random((10000, 1000))
    
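    def normal_np(A, b):
        ATA = np.einsum('ijk,ijl->ikl', A, A)
        ATb = np.einsum('ijk,ij->ik', A, b)
        # Trailing axis: NumPy >= 2.0 treats a 2-D b as a stack of matrices
        return np.linalg.solve(ATA, ATb[..., None])[..., 0]
    
    @jax.jit
    def normal_jax(A, b):
        ATA = jnp.einsum('ijk,ijl->ikl', A, A)
        ATb = jnp.einsum('ijk,ij->ik', A, b)
        # Same explicit trailing axis avoids the ambiguous batched 1-D case
        return jnp.linalg.solve(ATA, ATb[..., None])[..., 0]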
    
    @nb.njit(parallel=True)
    def normal_nb(A, b):
        assert A.shape[:-1] == b.shape
        assert A.ndim == 3
        output = np.zeros((A.shape[0], A.shape[2]))
        # prange splits the independent per-particle solves across threads
        for i in nb.prange(A.shape[0]):
            ai = A[i]
            aiT = ai.T
            bi = b[i]
            # Normal equations for one system: (A_i^T A_i) x = A_i^T b_i
            output[i] = np.linalg.solve(aiT @ ai, aiT @ bi)
        return output
    

    Timings:

    assert np.allclose(normal_np(A, b), normal_jax(A, b))
    assert np.allclose(normal_np(A, b), normal_nb(A, b))
    %timeit normal_np(A, b)
    %timeit normal_jax(A, b).block_until_ready()
    %timeit normal_nb(A, b)
    

    Output:

    882 ms ± 6.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    81.5 ms ± 294 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    46.7 ms ± 101 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
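
A caveat on the normal equations: forming A.T @ A squares the condition number of A, so nearly degenerate particle neighborhoods can lose accuracy. If that becomes an issue, a batched reduced QR keeps the same vectorized structure without forming A.T @ A. A sketch, assuming NumPy >= 1.22 (which added stacked input to `np.linalg.qr`) and full column rank; `lstsq_qr` is a hypothetical name, it has not been benchmarked here, and it will likely be somewhat slower than the normal equations:

```python
import numpy as np


def lstsq_qr(A, b):
    """Batched least squares via reduced QR: A = Q R, then solve R x = Q^T b.

    A has shape (M, N, K) with N >= K and full column rank; b has shape (M, N).
    """
    Q, R = np.linalg.qr(A, mode='reduced')  # Q: (M, N, K), R: (M, K, K)
    QTb = np.einsum('ijk,ij->ik', Q, b)     # Q^T b for every system, shape (M, K)
    # R is upper triangular; the general batched solver is used for simplicity,
    # with an explicit trailing axis so NumPy 1.x and 2.x behave the same
    return np.linalg.solve(R, QTb[..., None])[..., 0]
```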