Numba does not support sparse matrices, so I would like to write my own function for sparse-matrix multiplication in COO format.
import numpy as np
from numba import njit, prange
from scipy.sparse import coo_matrix
@njit
def triplet2lookup(br: np.ndarray, m: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Bucket the triplets of sparse matrix b by row index.

    :param br: row index of each triplet of sparse matrix b
    :param m: number of rows in sparse matrix b
    :return: (table, count), where table[r, :count[r]] holds the positions
        of the triplets whose row index is r
    """
    # Assumes no row holds more than m triplets, so every bucket fits.
    table = np.zeros((m, m), dtype=np.int32)
    count = np.zeros(m, dtype=np.int32)
    for i in range(len(br)):
        val = br[i]
        table[val, count[val]] = i
        count[val] += 1
    return table, count
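For intuition, here is a tiny check of the lookup structure (the demo data is hypothetical, not part of the benchmark below):
br_demo = np.array([2, 0, 2], dtype=np.int32)
table, count = triplet2lookup(br_demo, 3)
print(count)         # [1 0 2]: one triplet with row 0, none with row 1, two with row 2
print(table[2, :2])  # [0 2]: positions of the triplets whose row index is 2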
@njit(fastmath=True)
def _mul(ar, ac, av, br, bc, bv, an, bm):
    # bm serves as both b's row count (for the lookup) and column count.
    na = len(av)
    rr = np.empty(len(ar) * 2, dtype=np.int32)
    rc = np.empty(len(rr), dtype=np.int32)
    rv = np.empty(len(rr), dtype=np.float64)
    # Bucket the triplets of b by row once, so each column index of a can
    # be matched against the corresponding row of b in O(1) per triplet.
    table, count = triplet2lookup(br, bm)
    cnt = 0
    # slot[r, c] is the output position of entry (r, c), or -1 if that
    # entry has not been emitted yet (used to merge duplicate products).
    slot = np.full((an, bm), -1, dtype=np.int64)
    for i in range(na):
        target = ac[i]  # column of a must match row of b
        for j in range(count[target]):
            row_idx = ar[i]
            col_idx = bc[table[target, j]]
            prod = av[i] * bv[table[target, j]]
            s = slot[row_idx, col_idx]
            if s == -1:
                # First contribution to this entry: append a new triplet.
                rr[cnt] = row_idx
                rc[cnt] = col_idx
                rv[cnt] = prod
                slot[row_idx, col_idx] = cnt
                cnt += 1
                if cnt >= len(rr):
                    rr, rc, rv = extend_arr(rr, rc, rv)
            else:
                # Accumulate into the triplet emitted earlier for this entry.
                rv[s] += prod
    return rr[:cnt], rc[:cnt], rv[:cnt]
@njit
def _extend_arr(rr, dtype):
    # Double the capacity, copying the existing prefix.
    tmp_rr = np.empty(len(rr) * 2, dtype=dtype)
    tmp_rr[:len(rr)] = rr
    return tmp_rr

@njit
def extend_arr(rr, rc, rv):
    # Grow all three triplet arrays together so they stay the same length.
    rr = _extend_arr(rr, np.int32)
    rc = _extend_arr(rc, np.int32)
    rv = _extend_arr(rv, np.float64)
    return rr, rc, rv
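As a sanity check (a sketch I added; the sizes and seed are arbitrary), the kernel can be compared against SciPy on a small random instance:
rng = np.random.default_rng(0)
k, size = 200, 50
ra = rng.integers(0, size, k).astype(np.int32)
ca = rng.integers(0, size, k).astype(np.int32)
va = rng.random(k)
rb = rng.integers(0, size, k).astype(np.int32)
cb = rng.integers(0, size, k).astype(np.int32)
vb = rng.random(k)
rr, rc, rv = _mul(ra, ca, va, rb, cb, vb, size, size)
dense = coo_matrix((rv, (rr, rc)), shape=(size, size)).toarray()
ref = (coo_matrix((va, (ra, ca)), shape=(size, size))
       @ coo_matrix((vb, (rb, cb)), shape=(size, size))).toarray()
assert np.allclose(dense, ref)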
After comparing the speed against SciPy, I found that SciPy is much faster for big matrices.
n = 50000
m = 1000
row_a = np.random.randint(0, m, n)
col_a = np.random.randint(0, m, n)
val_a = np.random.random(n)
row_b = np.random.randint(0, m, n)
col_b = np.random.randint(0, m, n)
val_b = np.random.random(n)
a_sci = coo_matrix((val_a, (row_a, col_a)), shape=(m, m))
b_sci = coo_matrix((val_b, (row_b, col_b)), shape=(m, m))
%timeit scipy_res = a_sci @ b_sci
# 21.4 ms ± 722 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
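(One caveat worth noting, not in the original timing run: the first call to an njit function includes JIT compilation, so the function should be warmed up once before timing.)
# Warm-up call so %timeit measures steady-state speed, not compilation.
_ = _mul(row_a, col_a, val_a, row_b, col_b, val_b, m, m)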
%timeit rr, rc, rv = _mul(row_a, col_a, val_a, row_b, col_b, val_b, m, m)
# 202 ms ± 47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
On my machine, SciPy is almost 10 times faster. Is there anything like SIMD or parallel computation that can be applied to speed up the _mul function?
I have tried to parallelize _mul by changing @njit to @njit(parallel=True) and range to prange for either the inner or the outer loop, but that slows the algorithm down.
I figured out a fairly simple way of doing this which is only 15% slower than SciPy, and still allows you to use nopython mode. The basic idea is to call SciPy's matrix multiply.
If you do this in the most direct way, using @jit(forceobj=True), it has a problem: any function which calls _mul() cannot determine its return type, and therefore cannot be run in nopython mode. Any function that calls that function cannot be run in nopython mode either, and so on.
I tried working around this by specifying the return type with a signature, but I couldn't figure out how to get njit() functions to be able to use it. Eventually, I found nb.objmode(), which can mark a block of code that is allowed to run Python functions within a broader Numba function. Functions which call this function can still be marked with njit().
import numba as nb
from scipy.sparse import coo_matrix

@nb.njit()
def mul2(ar, ac, av, br, bc, bv, n):
    with nb.objmode(row='i4[:]', col='i4[:]', data='f8[:]'):
        a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
        b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
        result = (a_sci @ b_sci).tocoo()
        row = result.row
        col = result.col
        data = result.data
    return row, col, data
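To illustrate the point about nopython callers, here is a minimal sketch (the wrapper name and its reduction are my own, hypothetical): an njit() function can call mul2() and keep working on the result in nopython mode, because the objmode signatures give the caller concrete types.
@nb.njit()
def product_sum(ar, ac, av, br, bc, bv, n):
    # Still compiles in nopython mode; objmode is confined inside mul2().
    row, col, data = mul2(ar, ac, av, br, bc, bv, n)
    return data.sum()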
Warning: the documentation for objmode says that there is usually a better way of accomplishing your goals than using objmode. There may be a better way of doing this; I don't have a lot of experience with Numba.
Why is this slower than calling SciPy directly? The problem is that you want the result in COO format, but the matrix multiply produces results in CSR format. Converting takes extra time.
I also experimented with an option that outputs in CSR.
@nb.njit()
def mul3(ar, ac, av, br, bc, bv, n):
    with nb.objmode(data='f8[:]', indptr='i4[:]', indices='i4[:]'):
        a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
        b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
        result = a_sci @ b_sci
        data = result.data
        indptr = result.indptr
        indices = result.indices
    return data, indices, indptr
This version is only 2% slower than the direct SciPy option.
These functions benefit from the optimized algorithms present in SciPy, while being able to be called from Numba functions that use nopython mode for all other operations.
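For instance (a hypothetical sketch of my own, not from the measurements below), the CSR output of mul3() can be consumed directly in nopython code; here, per-row sums are computed from data and indptr:
@nb.njit()
def row_sums(data, indptr, n):
    out = np.zeros(n)
    for i in range(n):
        # Row i of a CSR matrix owns the slice data[indptr[i]:indptr[i + 1]].
        for k in range(indptr[i], indptr[i + 1]):
            out[i] += data[k]
    return out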
Timings:
Pure SciPy:             27.8 ms ± 471 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Original Numba:         184 ms ± 594 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Numba SciPy COO output: 32 ms ± 228 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Numba SciPy CSR output: 28.3 ms ± 685 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)