
Special case of sparse matrix multiplication


I'm trying to come up with a fast algorithm to compute AtLA (that is, AᵀLA, where At is the transpose of A), where

  • L is a symmetric n x n matrix of real numbers.
  • A is a sparse n x m matrix with m < n. Each row has exactly one non-zero element, and it equals 1. Every column is also guaranteed to have at most two non-zero elements.

I came up with one algorithm, but I feel there should be something faster.

Let's represent every column of A as a pair of the row numbers that hold its non-zero elements. If a column has only one non-zero element, its row number is listed twice. E.g., for the following matrix

Sparse matrix example (n = 5, m = 3):

  1 0 0
  0 1 0
  1 0 0
  0 1 0
  0 0 1

its representation would be:

column 0: [0, 2]; column 1: [1, 3]; column 2: [4, 4]
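A small Python sketch of building this pair representation from a dense A, for checking by hand. The function name `column_pairs` and the list-of-rows input format are assumptions for illustration. It raises an error for a column with zero or more than two non-zeros, since the pair representation has no encoding for those cases.

```python
def column_pairs(A):
    """Return one (row, row) pair per column of a dense 0/1 matrix A.

    A is a list of n rows of length m. A column with a single non-zero
    gets its row index twice, matching the representation above.
    """
    n, m = len(A), len(A[0])
    rows_of_col = [[] for _ in range(m)]
    for r in range(n):
        for c in range(m):
            if A[r][c] != 0:
                rows_of_col[c].append(r)  # rows are visited in increasing order

    pairs = []
    for c, rows in enumerate(rows_of_col):
        if len(rows) == 1:
            pairs.append((rows[0], rows[0]))  # single non-zero: row listed twice
        elif len(rows) == 2:
            pairs.append((rows[0], rows[1]))
        else:
            raise ValueError(f"column {c} has {len(rows)} non-zeros; expected 1 or 2")
    return pairs


A = [
    [1, 0, 0],
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
print(column_pairs(A))  # [(0, 2), (1, 3), (4, 4)]
```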

Or we can list it as a single array: A = [0, 2, 1, 3, 4, 4]. Now L' = LA can be calculated as follows, where L[j] is the j-th column of L (equal to its j-th row, since L is symmetric):

for (i = 0; i < A.length; i += 2):
  if A[i] != A[i + 1]:
    # sum of two column vectors, i/2-th column of L'
    L'[i/2] = L[A[i]] + L[A[i + 1]]
  else:
    L'[i/2] = L[A[i]]
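The first loop can be sketched in runnable Python as below. `times_a` and the test matrix are hypothetical names and data chosen for illustration; A is assumed to be the flat pair array. L' is stored row by row here (`Lp[row][c]`), and `L[row][r0]` reads column r0 of L directly, so the sketch does not depend on L being symmetric.

```python
def times_a(L, pairs):
    """L' = L A (n x m), built from the flat pair array without forming A.

    Column c of L' is column pairs[2c] of L plus column pairs[2c + 1] of L,
    or just the first one when both indices are equal.
    """
    n, m = len(L), len(pairs) // 2
    Lp = [[0.0] * m for _ in range(n)]
    for c in range(m):
        r0, r1 = pairs[2 * c], pairs[2 * c + 1]
        for row in range(n):
            value = L[row][r0]
            if r1 != r0:  # column c of A has two non-zeros
                value += L[row][r1]
            Lp[row][c] = value
    return Lp


L = [[float(i + j) for j in range(5)] for i in range(5)]  # symmetric 5 x 5 test matrix
pairs = [0, 2, 1, 3, 4, 4]
Lp = times_a(L, pairs)
print(Lp[1])  # [4.0, 6.0, 5.0]
```

Each of the m columns costs n additions at most, which is the O(mn) part of the first loop.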

To calculate L'' = AtL', we do the same thing once more, this time summing rows of L':

for (i = 0; i < A.length; i += 2):
  if A[i] != A[i + 1]:
    # sum of two row vectors, i/2-th row of L''
    L''[i/2] = L'[A[i]] + L'[A[i + 1]]
  else:
    L''[i/2] = L'[A[i]]
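Both loops together, as a runnable Python sketch with a brute-force dense check. The names `at_l_a_two_pass`, `selected` and `dense_at_l_a` are my own, and the dense reference is only there to confirm the two-pass result on a small example; it is far too slow for real use.

```python
def at_l_a_two_pass(L, pairs):
    """A^T L A via the two loops above, using the flat pair array.

    Pass 1 builds L' = L A (n x m); pass 2 builds L'' = A^T L' (m x m).
    """
    n, m = len(L), len(pairs) // 2

    def selected(c):
        # Row indices of the non-zeros in column c of A (one or two of them).
        r0, r1 = pairs[2 * c], pairs[2 * c + 1]
        return (r0,) if r0 == r1 else (r0, r1)

    # Pass 1: column c of L' is the sum of the selected columns of L.
    Lp = [[sum(L[row][r] for r in selected(c)) for c in range(m)]
          for row in range(n)]
    # Pass 2: row c of L'' is the sum of the selected rows of L'.
    return [[sum(Lp[r][k] for r in selected(c)) for k in range(m)]
            for c in range(m)]


def dense_at_l_a(L, A):
    """Reference check: the plain O(n^2 m^2) triple product with a dense A."""
    n, m = len(A), len(A[0])
    return [[sum(A[r][p] * L[r][s] * A[s][q]
                 for r in range(n) for s in range(n))
             for q in range(m)]
            for p in range(m)]


A = [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
pairs = [0, 2, 1, 3, 4, 4]
L = [[float(i + j) for j in range(5)] for i in range(5)]  # symmetric test matrix
print(at_l_a_two_pass(L, pairs))
# [[8.0, 12.0, 10.0], [12.0, 16.0, 12.0], [10.0, 12.0, 8.0]]
assert at_l_a_two_pass(L, pairs) == dense_at_l_a(L, A)
```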

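The time complexity of this approach is O(mn + m*m): the first loop sums columns of length n, and the second sums rows of L', which have length m. The space complexity (to get the final AtLA result) is O(n*n) if L itself is counted; the intermediate L' alone takes O(mn). I'm wondering whether it's possible to improve this to O(m*m) in terms of space and/or time.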
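One consequence of the constraints is worth making explicit before comparing against an O(m*m) target: they force n ≤ 2m. A short derivation using only the conditions stated above, where nnz(A) is the number of non-zero entries of A:

```latex
\begin{aligned}
\operatorname{nnz}(A) &= n
  && \text{each of the $n$ rows of $A$ holds exactly one 1}\\
\operatorname{nnz}(A) &\le 2m
  && \text{each of the $m$ columns holds at most two 1s}\\
\Longrightarrow\quad \tfrac{n}{2} \le m &< n
  && \text{(the upper bound is given)}\\
O(mn) + O(m^2) &= O(m^2)
  && \text{since } mn \le 2m^2\\
n^2 &\le 4m^2 = O(m^2)
  && \text{size of } L
\end{aligned}
```

So, asymptotically, the two-loop method is already O(m*m) in time, and even the input L is O(m*m) in size; any further gain is in constant factors, for example by not storing the n x m intermediate L'.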


Solution

  • The second loop only ever combines entries of L' (at most 2m rows of it), and every entry of L' is itself a sum of at most two entries of L. So there is no need to calculate and store L' in full.

    One way to avoid that is to change your first loop into a function and only calculate the individual elements of L' as they are needed.

    def L'(row,col):
      i=col*2
      if A[i] != A[i + 1]:
        # entry (row, col) of L': sum of two entries from row `row` of L
        return L[row][A[i]] + L[row][A[i + 1]] 
      else:
        return L[row][A[i]]
    
    for (i = 0; i < A.length; i += 2):
      if A[i] != A[i + 1]:
        for (k=0;k<m;k++):
          L''[i/2][k] = L'(A[i],k) + L'(A[i + 1],k)
      else:
        for (k=0;k<m;k++):
          L''[i/2][k] = L'(A[i],k)
    

    This should then have time and space complexity O(m*m): L'' has m*m entries, and each one needs at most four lookups into L.
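A runnable Python version of the answer's approach, as a sketch. `at_l_a_lazy` and `lp` are my names (`lp` stands in for `L'`, which is not a valid Python identifier), and A is assumed to be given as the flat pair array from the question. Entry (p, q) of AtLA is the sum of L[r][s] over the rows r selected by column p and the rows s selected by column q, which is why each entry costs at most four lookups.

```python
def at_l_a_lazy(L, pairs):
    """A^T L A computed entry by entry, never storing L' = L A.

    lp(row, col) returns a single entry of L' on demand, like the answer's
    L'(row, col); each entry of L'' then combines at most two such entries.
    """
    m = len(pairs) // 2

    def lp(row, col):
        # Entry (row, col) of L' = L A: at most two lookups into L.
        r0, r1 = pairs[2 * col], pairs[2 * col + 1]
        return L[row][r0] if r0 == r1 else L[row][r0] + L[row][r1]

    result = [[0.0] * m for _ in range(m)]  # output only; no n x m intermediate
    for p in range(m):
        r0, r1 = pairs[2 * p], pairs[2 * p + 1]
        for k in range(m):
            value = lp(r0, k)
            if r1 != r0:  # column p of A has two non-zeros
                value += lp(r1, k)
            result[p][k] = value
    return result


L = [[float(i + j) for j in range(5)] for i in range(5)]  # symmetric 5 x 5 test matrix
pairs = [0, 2, 1, 3, 4, 4]
print(at_l_a_lazy(L, pairs))
# [[8.0, 12.0, 10.0], [12.0, 16.0, 12.0], [10.0, 12.0, 8.0]]
```

Because L is symmetric, AtLA is symmetric too, so filling only k ≥ p and mirroring would roughly halve the lookups; the sketch keeps the full loop to match the answer.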