Tags: iterator, julia, sparse-matrix

How to iterate through all nonzero values of a sparse matrix and a normal matrix


I am using Julia and I want to iterate over the values of a matrix. This matrix can be either a normal (dense) matrix or a sparse matrix, but I do not know in advance which one it is. I would like to write code that works in both cases and is optimised for both.

For simplicity, I wrote an example that computes the sum of the vector's entries, each multiplied by a random value. What I actually want to do is similar, except that instead of multiplying by a random number, I call a function that takes a long time to compute.

using SparseArrays

# sparse: iterate over the stored values only; dense: over every entry
myiterator(m::SparseVector) = m.nzval
myiterator(m::AbstractVector) = m

function sumtimesrand(m)
   a = 0.
   for i in myiterator(m)
      a += i * rand()
   end
   return a
end


I = [1, 4, 3, 5]; V = [1, 2, -5, 3];
Msparse = sparsevec(I,V)
M = rand(5)
sumtimesrand(Msparse)
sumtimesrand(M)

I want my code to work this way, i.e., most of the code is shared, and using the right iterator makes it optimised for both cases (sparse and normal vector).
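Since the real workload replaces x * rand() with a slow function, the same pattern can take that function as an argument. This is a minimal sketch, not code from the original post: sumapply and the lambda x -> 2x are hypothetical stand-ins, and nonzeros is the documented SparseArrays accessor for the stored values of a sparse vector.

```julia
using SparseArrays

# Stored values for sparse vectors, every entry otherwise.
myiterator(m::SparseVector) = nonzeros(m)
myiterator(m::AbstractVector) = m

# Hypothetical generalization: `f` stands in for the slow per-entry function.
# `init = 0.0` keeps the result a Float64 and makes empty inputs return 0.0.
sumapply(f, m) = sum(f, myiterator(m); init = 0.0)

Msparse = sparsevec([1, 4, 3, 5], [1, 2, -5, 3])
sumapply(x -> 2x, Msparse)          # 2.0, f evaluated 4 times
sumapply(x -> 2x, Vector(Msparse))  # 2.0, f evaluated 5 times (once on the zero)
```

For a dense input, f(0) is still evaluated for every zero entry, so the two paths agree only when f(0) == 0.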

My question is: is there an iterator that does what I am trying to achieve? In this case the iterator returns the values, but an iterator over the indices would work too.
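An index-based variant can use the same dispatch idea. A minimal sketch, assuming only SparseVector and SparseMatrixCSC need the sparse path; the helper names nzindices and sumtimesrand_idx are hypothetical, while findnz (SparseArrays) and findall (Base) are standard functions. findnz returns the stored entries, which can include explicitly stored zeros.

```julia
using SparseArrays

# Hypothetical helper: the indices of the entries worth visiting.
# Sparse vector: indices of the stored entries.
nzindices(m::SparseVector) = findnz(m)[1]
# Sparse matrix: stored entries as CartesianIndex(row, col), column by column.
function nzindices(m::SparseMatrixCSC)
    rows, cols, _ = findnz(m)
    return CartesianIndex.(rows, cols)
end
# Dense arrays: scan every entry and keep the positions of the nonzeros.
nzindices(m::AbstractArray) = findall(!iszero, m)

# Same loop for every input type; only nzindices changes.
function sumtimesrand_idx(m)
    a = 0.0
    for idx in nzindices(m)
        a += m[idx] * rand()
    end
    return a
end

v = sparsevec([1, 4, 3, 5], [1, 2, -5, 3])
nzindices(v)          # [1, 3, 4, 5]
nzindices(Vector(v))  # [1, 3, 4, 5]
sumtimesrand_idx(v)
```

Indexing a SparseMatrixCSC with m[idx] searches within a column, so when only the values are needed, iterating over nonzeros(m) directly is cheaper.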

Cheers, Dylan


Solution

  • I think you almost had what you were asking for: just change your AbstractVector and SparseVector into AbstractArray and AbstractSparseArray, so that the same dispatch also covers matrices. But maybe I am missing something? See the minimal working example (MWE) below:

    using SparseArrays
    using BenchmarkTools # to compare performance
    
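    # note the changes here to "Array", so matrices are handled too.
    # nonzeros(m) is the documented accessor for the stored values
    # (same as m.nzval for SparseVector and SparseMatrixCSC)
    myiterator(m::AbstractSparseArray) = nonzeros(m)
    myiterator(m::AbstractArray) = m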
    
    function sumtimesrand(m)
       a = 0.
       for i in myiterator(m)
          a += i * rand()
       end
       return a
    end
    
    N = 1000
    spV = sprand(N, 0.01); V = Vector(spV)
    spM = sprand(N, N, 0.01); M = Matrix(spM)
    
    @btime sumtimesrand($spV); #  0.044936 μs
    @btime sumtimesrand($V);   #  3.919    μs
    
    @btime sumtimesrand($spM); # 0.041678 ms
    @btime sumtimesrand($M);   # 4.095    ms
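The dense method above still visits every zero of M. If the slow function should only ever see nonzero values, one option (an assumption about the goal, not part of the answer) is to skip the zeros lazily with Iterators.filter. The function slowf and the counter ncalls are hypothetical, only there to show how many times the function runs; the sparse method is restricted to SparseVector and SparseMatrixCSC because nonzeros is defined for those.

```julia
using SparseArrays

# Sparse inputs: only the stored values.
myiterator(m::Union{SparseVector,SparseMatrixCSC}) = nonzeros(m)
# Dense inputs: every entry is still scanned, but zeros are skipped lazily,
# so the expensive function never runs on them.
myiterator(m::AbstractArray) = Iterators.filter(!iszero, m)

function sumapply(f, m)
    a = 0.0
    for x in myiterator(m)
        a += f(x)
    end
    return a
end

const ncalls = Ref(0)
# Hypothetical stand-in for the slow function; counts its own calls.
slowf(x) = (ncalls[] += 1; x^2)

spM = sparse([1, 3], [1, 2], [2.0, -1.0], 3, 3)
M = Matrix(spM)
sumapply(slowf, spM)  # 5.0 after 2 calls
sumapply(slowf, M)    # 5.0 after 2 more calls, not 9
```

The dense path still costs one pass over all entries, but the expensive work happens only for the nonzeros, which matters when that function dominates the runtime.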