
Optimizing iteration using itertools.izip


I have an algorithm that iterates through all of the nonzero values of a matrix which looks something like this:

for row, col, val in itertools.izip(matrix.row, matrix.col, matrix.data):
    dostuff(row, col, val)

I understand that this is the fastest way of iterating through a sparse matrix in SciPy, as was discussed in "Iterating through a scipy.sparse vector (or matrix)".

My problem is that the function I perform at each evaluation takes another vector, let's call it vec, and does nothing if vec[row] is equal to 0, which in some cases is true for the majority of rows.

Therefore, I want to iterate only through the nonzero triplets (row, col, val) in the matrix for which vec[row] != 0, and skip all the others.

What I currently do is the simple and stupid solution of

import numpy as np
import scipy.sparse as sp
import itertools

N = 10000
matrix = sp.rand(N, N, density=0.0001, format='coo', dtype=None, random_state=None)
vec = np.zeros(N)
s = 0
for row, col, val in itertools.izip(matrix.row, matrix.col, matrix.data):
    if vec[row] != 0:
        s += vec[row] * val # in reality, some other function is here

which works, and runs faster than the original code when only a few rows have vec[row] != 0. However, the code runs slowly when all values of vec are nonzero, and that is a case I am not allowed to ignore (for example, if vec = np.ones(len(matrix.data))).

Therefore, I need some sort of extension of izip which will allow me to "conditionally" iterate through its output, so that I would write something like

for row, col, val in itertools.izip(matrix.row, matrix.col, matrix.data, lambda x: vec[x[0]] != 0):
    dostuff(row, col, val)

What are your suggestions? What will be the quickest way to do this?


Solution

  • You can just use NumPy's boolean (mask) indexing on the row, column and data arrays, so that the loop only ever sees the entries where vec[row] is nonzero:

    which = vec[matrix.row] != 0
    
    rows = matrix.row[which]
    cols = matrix.col[which]
    data = matrix.data[which]
    
    s = 0
    for row, col, val in itertools.izip(rows, cols, data):
        s += vec[row] * val  # vec[row] is nonzero for every entry that survived the mask
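
For reference, here is a minimal sketch of the same masking idea on Python 3, where itertools.izip no longer exists and the built-in zip is already lazy. The helper name masked_triplets and the tiny 4x4 matrix are hypothetical and only there for illustration. The last line also shows that, for the specific sum in the question, the loop could be dropped entirely; that only applies if your real function can be written with array operations, which is an assumption.

```python
import numpy as np
import scipy.sparse as sp


def masked_triplets(matrix, vec):
    """Yield (row, col, val) for stored entries of a COO matrix whose row has vec[row] != 0.

    Hypothetical helper: `matrix` is assumed to be in COO format and
    `vec` to have one entry per matrix row.
    """
    keep = vec[matrix.row] != 0  # boolean mask, one flag per stored entry
    return zip(matrix.row[keep], matrix.col[keep], matrix.data[keep])


# Small illustrative matrix with five stored entries
rows = np.array([0, 0, 1, 2, 3])
cols = np.array([1, 3, 2, 0, 3])
vals = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
matrix = sp.coo_matrix((vals, (rows, cols)), shape=(4, 4))

vec = np.array([2.0, 0.0, 0.5, 0.0])  # rows 1 and 3 are skipped

# Loop form: only entries from rows 0 and 2 are visited
s_loop = 0.0
for row, col, val in masked_triplets(matrix, vec):
    s_loop += vec[row] * val

# Vectorized form of the same sum, with no Python-level loop
keep = vec[matrix.row] != 0
s_vec = float(np.dot(vec[matrix.row[keep]], matrix.data[keep]))
```

The mask is computed once in NumPy, so the per-entry `if` disappears from the Python loop. When every value of vec is nonzero, the mask keeps everything and the loop costs about the same as the unfiltered one, plus one pass to build the mask.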