Tags: python, matrix, scipy, sparse-matrix, numba

Is there any way to speed up my NUMBA implementation of sparse matrix multiplication?


Numba does not support sparse matrices, so I would like to write my own function for sparse matrix multiplication in COO format.

import numpy as np
from numba import njit, prange
from numba.core.types import ListType, int32, DictType, float64
from numba.typed import Dict, List
from scipy.sparse import coo_matrix

LIST_TYPE_TRIP = ListType(int32)
DICT_TYPE_TRIP = DictType(keyty=int32, valty=float64)
TOLERANCE = 1e-6

@njit
def triplet2lookup(bc: np.ndarray, m: int) -> tuple[np.ndarray, np.ndarray]:
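    """
    Build a lookup from each column index to the positions where it occurs.

    :param bc: column index of sparse matrix b
    :param m: num of cols in sparse matrix b
    :return: (table, count), where table[v, :count[v]] lists the positions i with bc[i] == v
    """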
    table = np.zeros((m, m), dtype=np.int32)
    count = np.zeros(m, dtype=np.int32)
    for i in range(len(bc)):
        val = bc[i]
        table[val, count[val]] = i
        count[val] += 1
    return table, count


@njit(fastmath=True)
def _mul(ar, ac, av, br, bc, bv, an, bm):
    na = len(av)
    rr = np.empty(len(ar) * 2, dtype=np.int32)
    rc = np.empty(len(rr), dtype=np.int32)
    rv = np.empty(len(rr))
    table, count = triplet2lookup(bc, bm)
    cnt = 0
    hash_mat = np.zeros((an, bm))  # dense an x bm scratch array; note it is never written to below
    for i in range(na):
        target = ac[i]
        for j in range(count[target]):
            row_idx = ar[i]
            col_idx = br[table[target, j]]
            rr[cnt] = row_idx
            rc[cnt] = col_idx
            rv[cnt] = av[i] * bv[table[target, j]]
            cnt += 1
            if hash_mat[row_idx, col_idx] < TOLERANCE:
                rr[cnt] = row_idx
                rc[cnt] = col_idx
                rv[cnt] = av[i] * bv[table[target, j]]
                cnt += 1
            else:
                rv[cnt] += av[i] * bv[table[target, j]]
            if cnt >= len(rr):
                rr, rc, rv = extend_arr(rr, rc, rv)

    rr = rr[:cnt]
    rc = rc[:cnt]
    rv = rv[:cnt]
    return rr, rc, rv


@njit
def _extend_arr(rr, dtype):
    tmp_rr = np.empty(len(rr) * 2, dtype=dtype)
    tmp_rr[:len(rr)] = rr
    return tmp_rr


@njit
def extend_arr(rr, rc, rv):
    rr = _extend_arr(rr, np.int32)
    rc = _extend_arr(rc, np.int32)
    rv = _extend_arr(rv, np.float64)
    return rr, rc, rv

Comparing the speed against SciPy, I found that SciPy is much faster for large matrices.

n = 50000
m = 1000
row_a = np.random.randint(0, m, n)
col_a = np.random.randint(0, m, n)
val_a = np.random.random(n)

row_b = np.random.randint(0, m, n)
col_b = np.random.randint(0, m, n)
val_b = np.random.random(n)

a_sci = coo_matrix((val_a, (row_a, col_a)), shape=(m, m))
b_sci = coo_matrix((val_b, (row_b, col_b)), shape=(m, m))
%timeit scipy_res = a_sci @ b_sci
# 21.4 ms ± 722 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit rr, rc, rv = _mul(row_a, col_a, val_a, col_b, row_b, val_b, m, m)
# 202 ms ± 47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

On my machine, SciPy is almost 10 times faster. Is there anything, such as SIMD or parallel computation, that could be applied to speed up the _mul function?

I tried parallelizing _mul by changing @njit to @njit(parallel=True) and replacing range with prange in either the inner or the outer loop, but that slows the algorithm down.
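
For checking a hand-written kernel on small inputs, it can help to have a slow but obviously correct reference. The following is a minimal pure-Python sketch, not part of the original code: coo_matmul_reference is a hypothetical helper, and here br and bc really are the row and column indices of B (unlike the swapped arguments in the _mul call above). Entries that land on the same (row, col) are summed, which is what SciPy's product does as well.

```python
from collections import defaultdict

def coo_matmul_reference(ar, ac, av, br, bc, bv):
    """Multiply two matrices given as COO triplets (row, col, value).

    Returns a dict mapping (row, col) -> summed value of A @ B.
    """
    # Group the entries of B by row, so each entry A[i, k] can find
    # all entries B[k, j] it has to be multiplied with.
    b_by_row = defaultdict(list)
    for r, c, v in zip(br, bc, bv):
        b_by_row[r].append((c, v))

    out = defaultdict(float)
    for i, k, a in zip(ar, ac, av):
        for j, b in b_by_row.get(k, ()):
            out[(i, j)] += a * b  # accumulate duplicates instead of appending
    return dict(out)

# A = [[1, 2], [0, 3]] and B = [[4, 0], [5, 6]] as COO triplets
result = coo_matmul_reference(
    [0, 0, 1], [0, 1, 1], [1.0, 2.0, 3.0],
    [0, 1, 1], [0, 0, 1], [4.0, 5.0, 6.0],
)
```

This is far too slow for the benchmark sizes, but it gives a ground truth to compare a faster implementation against.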


Solution

  • I figured out a fairly simple way of doing this which is only about 15% slower than calling SciPy directly, and still allows you to use nopython mode. The basic idea is to call SciPy's matrix multiply from inside a Numba function.

    If you do this in the most direct way, using @jit(forceobj=True), it has a problem: Numba cannot determine the return type of _mul() in any function that calls it, so that caller cannot be compiled in nopython mode. The same restriction then propagates to every function further up the call chain.

    I tried working around this by specifying the return type with a signature, but I couldn't figure out how to make it usable from njit() functions. Eventually, I found nb.objmode(), which marks a block of code that is allowed to run Python functions inside a broader Numba function. Functions that call this function can still be marked with njit().

    import numba as nb
    from scipy.sparse import coo_matrix
    
    @nb.njit()
    def mul2(ar, ac, av, br, bc, bv, n):
        with nb.objmode(row='i4[:]', col='i4[:]', data='f8[:]'):
            a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
            b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
            result = (a_sci @ b_sci).tocoo()
            row = result.row
            col = result.col
            data = result.data
        return row, col, data
    

    Warning: The documentation for objmode says that there is usually a better way of accomplishing your goals than using objmode. There may be a better way of doing this - I don't have a lot of experience in Numba.

    Why is this slower than calling SciPy directly? The problem is that you want the result in COO format, but the matrix multiply produces results in CSR format. Converting takes extra time.
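
To make the conversion cost concrete: CSR stores one row pointer per row (indptr), while COO needs an explicit row index for every stored value. A minimal NumPy sketch of that expansion follows; csr_to_coo_rows is a hypothetical helper written for illustration, not SciPy's actual implementation.

```python
import numpy as np

def csr_to_coo_rows(indptr):
    """Expand a CSR row-pointer array into one row index per stored value."""
    n_rows = len(indptr) - 1
    # indptr[r + 1] - indptr[r] is the number of stored values in row r.
    counts = np.diff(indptr)
    return np.repeat(np.arange(n_rows, dtype=np.int32), counts)

# Three rows holding 2, 0 and 3 stored values respectively
indptr = np.array([0, 2, 2, 5], dtype=np.int32)
rows = csr_to_coo_rows(indptr)
```

The column indices and data can be reused unchanged, so the extra work is proportional to the number of stored values in the result.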

    I also experimented with an option that outputs in CSR.

    @nb.njit()
    def mul3(ar, ac, av, br, bc, bv, n):
        with nb.objmode(data='f8[:]', indptr='i4[:]', indices='i4[:]'):
            a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
            b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
            result = a_sci @ b_sci
            data = result.data
            indptr = result.indptr
            indices = result.indices
        return data, indices, indptr
    

    This version is only 2% slower than the direct SciPy option.

    These functions benefit from the optimized algorithms present in SciPy, while being able to be called from Numba functions that use nopython mode for all other operations.

    Timings:

    Pure SciPy
    27.8 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    Original Numba
    184 ms ± 594 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    Numba SciPy COO output
    32 ms ± 228 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    Numba SciPy CSR output
    28.3 ms ± 685 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)