
Matrix multiplication of row-major recursively


I'm programming my own matrix module for fun and practice (time and space complexity do not matter). Now I want to implement matrix multiplication and I am struggling with it, probably because I am using Haskell and haven't had much experience with it. This is my data type:

data Matrix a =
M {
  rows::Int,
  cols::Int,
  values::[a]
}

It stores a 3x2 matrix row by row (row-major order) in a flat list, like this:

1 2
3 4
5 6
= [1,2,3,4,5,6]
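To make the layout concrete: element (row, col) lives at position `cols * row + col` in the flat list, the same formula `valueAtIndex` uses below, and the reverse mapping is a `divMod` by `cols`. A small sketch; `flatIndex`, `position` and `example` are hypothetical names, and the data type is repeated so the snippet stands alone:

```haskell
-- The question's flat, row-major matrix type, repeated so this compiles alone.
data Matrix a = M { rows :: Int, cols :: Int, values :: [a] }

-- (row, col) -> position in the flat list:
-- skip `row` complete rows of `cols` elements each, then `col` more.
flatIndex :: Matrix a -> (Int, Int) -> Int
flatIndex m (r, c) = cols m * r + c

-- Inverse mapping (hypothetical helper): position -> (row, col).
position :: Matrix a -> Int -> (Int, Int)
position m n = n `divMod` cols m

-- The 3x2 example from above.
example :: Matrix Int
example = M 3 2 [1, 2, 3, 4, 5, 6]
```

For `example`, `flatIndex example (2, 1)` is `5`, and `values example !! 5` is `6`, the bottom-right entry.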

I have a somewhat working transpose function:

transpose::(Matrix a)->(Matrix a)
transpose (M rows cols values) = M cols rows (aux values 0 0 [])
  where
   aux::[a]->Int->Int->[a]->[a]
   aux values row col transposed 
     | cols > col =
       if rows > row then 
         aux values (row+1) col (transposed ++ [valueAtIndex (M rows cols values) (row,col)])
       else aux values 0 (col+1) transposed
     | otherwise = transposed
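As a side note, the accumulator and the repeated `++` in `aux` are not strictly needed: because row j of the transpose is column j of the original, the transposed list can be produced directly with a list comprehension over the same `cols * row + col` formula. A sketch under that assumption; `transposeFlat` is a hypothetical name and the data type is repeated for self-containment:

```haskell
data Matrix a = M { rows :: Int, cols :: Int, values :: [a] }
  deriving (Show, Eq)

-- Transpose a flat row-major matrix.
-- Outer loop: columns j of the input (= rows of the result).
-- Inner loop: rows i of the input (= columns of the result).
transposeFlat :: Matrix a -> Matrix a
transposeFlat (M r c vs) =
  M c r [vs !! (c * i + j) | j <- [0 .. c - 1], i <- [0 .. r - 1]]
```

For the 3x2 example above, `transposeFlat (M 3 2 [1,2,3,4,5,6])` is a 2x3 matrix whose values are `[1,3,5,2,4,6]`, i.e. the rows `1 3 5` and `2 4 6`.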

To index the elements in the list, I am using this function:

valueAtIndex::(Matrix a)->(Int, Int)->a
valueAtIndex (M rows cols values) (row, col) 
  | rows <= row || cols <= col = error "indices too large for given Matrix"
  | otherwise = values !! (cols * row + col)
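`valueAtIndex` only rejects indices that are too large. Negative indices slip through: with only the upper-bound test, (1, -1) maps to flat position `cols - 1`, which is element (0, cols - 1), so a wrong value comes back without any error. A sketch of a total variant that checks both bounds and returns `Maybe` instead of calling `error`; the name `valueAtIndexMaybe` is hypothetical:

```haskell
data Matrix a = M { rows :: Int, cols :: Int, values :: [a] }

-- Total lookup: Nothing for any out-of-range index (negative or too large),
-- otherwise the element at flat position cols * row + col.
valueAtIndexMaybe :: Matrix a -> (Int, Int) -> Maybe a
valueAtIndexMaybe (M r c vs) (row, col)
  | row < 0 || col < 0 || row >= r || col >= c = Nothing
  | otherwise = Just (vs !! (c * row + col))
```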

From my understanding, for m1: 2x3 and m2: 3x2 the four entries of the 2x2 result combine row i of m1 with column j of m2 like this:

c(0,0) = m1(0,0)*m2(0,0)+m1(0,1)*m2(1,0)+m1(0,2)*m2(2,0)
c(0,1) = m1(0,0)*m2(0,1)+m1(0,1)*m2(1,1)+m1(0,2)*m2(2,1)
c(1,0) = m1(1,0)*m2(0,0)+m1(1,1)*m2(1,0)+m1(1,2)*m2(2,0)
c(1,1) = m1(1,0)*m2(0,1)+m1(1,1)*m2(1,1)+m1(1,2)*m2(2,1)
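In general, for an m×n matrix A and an n×p matrix B, the product C = AB is m×p and each entry combines row i of A with column j of B (zero-based indices, as in `valueAtIndex`). In the flat row-major lists, each matrix is indexed with its own column count:

```latex
\begin{aligned}
c_{ij} &= \sum_{k=0}^{n-1} a_{ik}\, b_{kj}, && 0 \le i < m,\ 0 \le j < p \\
a_{ik} &= \mathrm{values}_A[\, n\,i + k \,] \\
b_{kj} &= \mathrm{values}_B[\, p\,k + j \,] \\
c_{ij} &= \mathrm{values}_C[\, p\,i + j \,]
\end{aligned}
```

For the 2x3 times 3x2 case above, m = 2, n = 3 and p = 2.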

Now I need a function that takes two matrices with cols m1 == rows m2 and then somehow recursively calculates the resulting matrix:

multiplyMatrix::Num a=>(Matrix a)->(Matrix a)->(Matrix a)
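If you want to keep the flat representation, one way to write this recursively is to walk the flat positions of the result from 0 to rows m1 * cols m2 - 1, recover (i, j) with `divMod`, and compute each entry as the sum over k of m1(i,k) * m2(k,j). A minimal sketch under those assumptions; `at`, `go` and `entry` are hypothetical helpers, and like the rest of the question it ignores efficiency (`!!` is linear):

```haskell
data Matrix a = M { rows :: Int, cols :: Int, values :: [a] }
  deriving (Show, Eq)

-- Element (i, j) of a flat row-major matrix (no bounds check).
at :: Matrix a -> Int -> Int -> a
at (M _ c vs) i j = vs !! (c * i + j)

multiplyMatrix :: Num a => Matrix a -> Matrix a -> Matrix a
multiplyMatrix m1 m2
  | n /= rows m2 = error "multiplyMatrix: cols m1 must equal rows m2"
  | otherwise    = M m p (go 0)
  where
    m = rows m1
    n = cols m1
    p = cols m2
    -- Recurse over the result's flat positions; position idx is (idx `divMod` p).
    go idx
      | idx >= m * p = []
      | otherwise    = uncurry entry (idx `divMod` p) : go (idx + 1)
    -- c(i,j) = m1(i,0)*m2(0,j) + ... + m1(i,n-1)*m2(n-1,j)
    entry i j = sum [at m1 i k * at m2 k j | k <- [0 .. n - 1]]
```

With m1 = `M 2 3 [1..6]` and m2 = `M 3 2 [7..12]`, the result is a 2x2 matrix with values `[58,64,139,154]`.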

Solution

  • First of all, I'm not really convinced that such a flat list is a good idea. A list in Haskell is a singly linked list, so accessing the k-th element takes O(k) time. For an m×n matrix that means it takes O(m·n) steps to reach the last element. With a 2D linked list (a list of row lists), reaching element (i, j) takes about i + j steps, so at most O(m + n). Yes, there is some overhead since you use more "cons" data constructors, but the amount of traversing is typically lower. If you really need fast random access, you should use arrays, vectors, etc., but then there are other design decisions to make.

    So I propose we model the matrix as:

    data Matrix a = M {
      rows :: Int,
      cols :: Int,
      values :: [[a]]
    }

    With this data type we can define transpose as:

    transpose' :: Matrix a -> Matrix a
    transpose' (M r c as) = M c r (trans as)
        where trans [] = []
              trans ([]:_) = []  -- rows exhausted: stop instead of calling head on []
              trans xs = map head xs : trans (map tail xs)
    

    (here we assume that the list of lists is always rectangular)

    So now for the matrix multiplication. If A and B are two matrices and C = A × B, then c(i,j) is the dot product of the i-th row of A and the j-th column of B, or equivalently of the i-th row of A and the j-th row of Bᵀ (the transpose of B). We can thus define the dot product as:

    dot_prod :: Num a => [a] -> [a] -> a
    dot_prod xs ys = sum (zipWith (*) xs ys)
    

    Now it is only a matter of iterating over the rows of A and the rows of Bᵀ, and placing each dot product in the right list:

    mat_mul :: Num a => Matrix a -> Matrix a -> Matrix a
    mat_mul (M r ca xss) m2 | ca /= rb  = error "Invalid matrix shapes"
                            | otherwise = M r c (matmul xss)
        where (M c rb yss) = transpose' m2   -- c = cols m2, rb = rows m2
              -- one result row per row of the left matrix
              matmul [] = []
              matmul (xs:xss') = generaterow yss xs : matmul xss'
              -- dot the row xs with every row of m2's transpose (every column of m2)
              generaterow [] _ = []
              generaterow (ys:yss') xs = dot_prod xs ys : generaterow yss' xs
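To move from the question's flat type to the nested representation proposed in this answer, the flat list only needs to be cut into rows of `cols` elements. A sketch; `FlatMatrix`, `fromFlat`, `toFlat` and `valueAt` are hypothetical names (the flat type is renamed so both types fit in one file):

```haskell
-- The question's flat representation, renamed here to avoid a clash.
data FlatMatrix a = FM { frows :: Int, fcols :: Int, fvalues :: [a] }
  deriving (Show, Eq)

-- The proposed nested representation: one inner list per row.
data Matrix a = M { rows :: Int, cols :: Int, values :: [[a]] }
  deriving (Show, Eq)

-- Cut the flat list into rows of c elements (assumes c > 0 for a non-empty list).
fromFlat :: FlatMatrix a -> Matrix a
fromFlat (FM r c vs) = M r c (chunk vs)
  where
    chunk [] = []
    chunk xs = take c xs : chunk (drop c xs)

-- Going back is just concatenating the rows.
toFlat :: Matrix a -> FlatMatrix a
toFlat (M r c xss) = FM r c (concat xss)

-- With nested rows, element (i, j) costs about i + j steps instead of c * i + j.
valueAt :: Matrix a -> (Int, Int) -> a
valueAt (M _ _ xss) (i, j) = (xss !! i) !! j
```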
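The helper `trans` in `transpose'` is meant to do the same job as `transpose` from `Data.List` in base, which could be reused instead. A sketch; `transposeL` is a hypothetical name. For rectangular input both give the same result, and `Data.List.transpose` additionally stops cleanly once every row is exhausted and tolerates ragged rows by skipping missing elements:

```haskell
import qualified Data.List as L

data Matrix a = M { rows :: Int, cols :: Int, values :: [[a]] }
  deriving (Show, Eq)

-- Swap the dimensions and let Data.List.transpose turn rows into columns.
transposeL :: Matrix a -> Matrix a
transposeL (M r c xss) = M c r (L.transpose xss)
```

For example, `L.transpose [[1,2,3],[4,5],[6]]` is `[[1,4,6],[2,5],[3]]`.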
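A small worked example of the dot-product view, using A = [[1,2,3],[4,5,6]] and B = [[7,8],[9,10],[11,12]]: entry c(0,1) is row 0 of A dotted with column 1 of B, and column 1 of B is exactly row 1 of Bᵀ. The names `col1`, `col1'` and `c01` are illustrative only. Note that `zipWith` silently truncates to the shorter list, so `dot_prod` on rows of mismatched length returns a wrong number rather than failing, which is why the shape check in `mat_mul` matters:

```haskell
import qualified Data.List as L

dot_prod :: Num a => [a] -> [a] -> a
dot_prod xs ys = sum (zipWith (*) xs ys)

a, b :: [[Int]]
a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]

-- Column 1 of B read directly: the second element of every row.
col1 :: [Int]
col1 = map (!! 1) b

-- The same numbers, read as row 1 of B^T.
col1' :: [Int]
col1' = L.transpose b !! 1

-- c(0,1) = 1*8 + 2*10 + 3*12 = 64
c01 :: Int
c01 = dot_prod (a !! 0) col1'
```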
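The two helper recursions in `mat_mul` (`matmul` and `generaterow`) are just two nested maps, so the same algorithm can also be written as a nested list comprehension. A sketch that swaps in `Data.List.transpose` for the hand-written transpose; `matMul` is a hypothetical name:

```haskell
import qualified Data.List as L

data Matrix a = M { rows :: Int, cols :: Int, values :: [[a]] }
  deriving (Show, Eq)

dot_prod :: Num a => [a] -> [a] -> a
dot_prod xs ys = sum (zipWith (*) xs ys)

-- Same algorithm as mat_mul: each output row pairs one row of A
-- with every row of B^T (i.e. every column of B).
matMul :: Num a => Matrix a -> Matrix a -> Matrix a
matMul (M r ca xss) (M rb c yss)
  | ca /= rb  = error "Invalid matrix shapes"
  | otherwise = M r c [[dot_prod xs ys | ys <- bt] | xs <- xss]
  where
    bt = L.transpose yss  -- rows of B^T are the columns of B

a, b :: Matrix Int
a = M 2 3 [[1, 2, 3], [4, 5, 6]]
b = M 3 2 [[7, 8], [9, 10], [11, 12]]
```

Here `values (matMul a b)` evaluates to `[[58,64],[139,154]]`.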