Search code examples
algorithm · recursion · matrix · matrix-multiplication · strassen

Strassen's Subcubic Matrix Multiplication Algorithm with recursion


I am having a difficult time conceptualizing how to implement Strassen's version of this algorithm.

For background, I have the following pseudocode for the iterative version:

def Matrix(a,b):
    """Return the product of matrices a and b using the schoolbook method.

    a is an m x k list of rows and b is a k x p list of rows; the result is
    a new m x p list of rows where entry [i][j] is the dot product of row i
    of a with column j of b.

    Raises ValueError when the number of columns of a differs from the
    number of rows of b.
    """
    if a and len(a[0]) != len(b):
        raise ValueError("columns of a must match rows of b")

    # Transpose b once so every column can be walked like a row.
    columns = list(zip(*b))

    # Build each result row directly; indexing into an empty list (as the
    # earlier draft did) raises IndexError.
    return [
        [sum(x * y for x, y in zip(row, column)) for column in columns]
        for row in a
    ]

I also have the following pseudocode for the initial divide and conquer version:

def recMatrix(x,y):
    """Sketch of the eight-product divide-and-conquer matrix multiply.

    x and y are indexed by quadrant. From the way the products are combined
    below, index 0 acts as the top-left quadrant, 1 as top-right, 2 as
    bottom-left and 3 as bottom-right, and z uses the same layout.

    NOTE(review): this is pseudocode rather than runnable Python; how a
    matrix is actually split into the four indexed quadrants is not shown.
    """
    # Base case: a single element on each side is multiplied directly.
    if(len(x) == 1):
        return x[0] * y[0]

    # NOTE(review): z starts out empty, so the indexed assignments below
    # would raise IndexError in real Python; it needs four slots first.
    z = []
    # Each output quadrant is the sum of two recursive quadrant products.
    # NOTE(review): "+" is presumably meant as element-wise matrix addition;
    # on plain Python lists it would concatenate instead -- confirm.
    z[0] = recMatrix(x[0], y[0]) + recMatrix(x[1], y[2])
    z[1] = recMatrix(x[0], y[1]) + recMatrix(x[1], y[3])
    z[2] = recMatrix(x[2], y[0]) + recMatrix(x[3], y[2])
    z[3] = recMatrix(x[2], y[1]) + recMatrix(x[3], y[3])

    return z

So my question is: is my understanding of the divide and conquer method even correct, and if so, how can I modify it to allow for Strassen's method? (This is not homework.)

Specifically, where I am having a hard time conceptualizing it is the arithmetic done on the sub-matrices before (or after) the recursion, e.g. where P1 = A(F-H). If the recursion is what actually multiplies the base elements, how does the Strassen recursion take care of the arithmetic (add and subtract) on the matrices? I have the following pseudocode to show where my brain is at:

def recMatrix(x,y):
    """Sketch of Strassen's seven-product recursion over quadrants.

    x and y are indexed by quadrant (0..3) and z collects the four output
    quadrants built from the seven products p1..p7.

    NOTE(review): this is pseudocode rather than runnable Python. The
    recursive calls go to recMatrix2, which is not defined in this snippet
    (the function itself is named recMatrix), and "+" / "-" are presumably
    meant as element-wise matrix operations -- confirm.
    """
    # Base case: a single element on each side is multiplied directly.
    if(len(x) == 1):
        return x[0] * y[0]

    # NOTE(review): z starts out empty, so the indexed assignments at the
    # bottom would raise IndexError in real Python.
    z = []
    # Seven recursive products on sums/differences of quadrants.
    p1 = recMatrix2(x[0], (y[1] - y[3]));
    # NOTE(review): p2 and p3 pass the y quadrant as the left operand. Matrix
    # products do not commute, so if these are meant as (x[0] + x[1]) * y[3]
    # and (x[2] + x[3]) * y[0] the argument order should be swapped -- verify.
    p2 = recMatrix2(y[3], (x[0] + x[1]));
    p3 = recMatrix2(y[0], (x[2] + x[3]));
    p4 = recMatrix2(x[3], (y[2] - y[0]));
    p5 = recMatrix2((x[0] + x[3]), (y[0] + y[3]));
    p6 = recMatrix2((x[1] - x[3]), (y[2] + y[3]));
    # NOTE(review): p7 uses the same y operand as p5; the usual formulation
    # paired with these z combinations is (x[0] - x[2]) * (y[0] + y[1]) --
    # verify against the reference being followed.
    p7 = recMatrix2((x[0] - x[3]), (y[0] + y[3]));

    # Recombine the seven products into the four output quadrants.
    z[0] = p5 + p4 - p2 + p6;
    z[1] = p1 + p2;
    z[2] = p3 + p4;
    z[3] = p1 + p5 - p3 - p7;

    return z

Solution

  • Found an implementation that does what I'm looking for... namely, it specifically shows how/when to recurse: https://github.com/MartinThoma/matrix-multiplication/blob/master/Python/strassen-algorithm.py

    #!/usr/bin/python
    # -*- coding: utf-8 -*-
    
    from optparse import OptionParser
    from math import ceil, log
    
    def read(filename):
        """Load two integer matrices from a tab-separated text file.

        Every non-empty line is one matrix row of tab-separated integers.
        Rows go to the first matrix until an empty line is seen; rows after
        that go to the second matrix.

        Returns a tuple (A, B) of lists of lists of int.
        """
        A = []
        B = []
        matrix = A
        # The context manager closes the file even if a row fails to parse.
        with open(filename, 'r') as handle:
            for line in handle.read().splitlines():
                if line != "":
                    # A real list (not a lazy map object) so callers can use
                    # len() and indexing on every row.
                    matrix.append([int(cell) for cell in line.split("\t")])
                else:
                    matrix = B
        return A, B
    
    def printMatrix(matrix):
        """Write the matrix to standard output, one row per line.

        The entries of each row are converted with str() and separated by
        tab characters.
        """
        for line in matrix:
            # Parenthesised single-argument form prints the same text on
            # Python 2 and parses on Python 3.
            print("\t".join(map(str, line)))
    
    def add(A, B):
        """Return a new matrix holding the element-wise sum A + B.

        A and B are lists of rows and are expected to have the same shape
        (square or rectangular). Rows and entries are paired positionally;
        anything in the larger operand beyond the shared shape is ignored.
        Neither input is modified.
        """
        return [
            [left + right for left, right in zip(row_a, row_b)]
            for row_a, row_b in zip(A, B)
        ]
    
    def subtract(A, B):
        """Return a new matrix holding the element-wise difference A - B.

        A and B are lists of rows and are expected to have the same shape
        (square or rectangular). Rows and entries are paired positionally;
        anything in the larger operand beyond the shared shape is ignored.
        Neither input is modified.
        """
        return [
            [left - right for left, right in zip(row_a, row_b)]
            for row_a, row_b in zip(A, B)
        ]
    
    def _quadrants(M, half):
        """Split the leading (2*half) x (2*half) part of M into four blocks.

        Returns (top_left, top_right, bottom_left, bottom_right), each a new
        half x half list of rows.
        """
        full = 2 * half
        upper = M[:half]
        lower = M[half:full]
        return (
            [row[:half] for row in upper],
            [row[half:full] for row in upper],
            [row[:half] for row in lower],
            [row[half:full] for row in lower],
        )

    def strassenR(A, B):
        """ Implementation of the strassen algorithm, similar to 
            http://en.wikipedia.org/w/index.php?title=Strassen_algorithm&oldid=498910018#Source_code_of_the_Strassen_algorithm_in_C_language

            A and B are square matrices (lists of rows) of the same size n.
            Each level splits both operands into four quadrants, forms the
            seven products p1..p7 on sums/differences of those quadrants and
            recombines them into the quadrants of the result, which is
            returned as a new n x n list of rows.

            The half-size products are computed through strassen(), which
            pads its operands up to a power of two before recursing here.
        """
        n = len(A)

        # Trivial Case: 1x1 Matrices
        if n == 1:
            return [[A[0][0]*B[0][0]]]

        # An odd size cannot be halved evenly; let the padding wrapper grow
        # it to a power of two instead of dropping the last row and column.
        if n % 2:
            return strassen(A, B)

        # Floor division keeps the quadrant size an int on Python 2 and 3.
        half = n // 2

        # dividing the matrices in 4 sub-matrices:
        a11, a12, a21, a22 = _quadrants(A, half)
        b11, b12, b21, b22 = _quadrants(B, half)

        # Calculating p1 to p7:
        p1 = strassen(add(a11, a22), add(b11, b22))       # (a11+a22) * (b11+b22)
        p2 = strassen(add(a21, a22), b11)                 # (a21+a22) * (b11)
        p3 = strassen(a11, subtract(b12, b22))            # (a11) * (b12-b22)
        p4 = strassen(a22, subtract(b21, b11))            # (a22) * (b21-b11)
        p5 = strassen(add(a11, a12), b22)                 # (a11+a12) * (b22)
        p6 = strassen(subtract(a21, a11), add(b11, b12))  # (a21-a11) * (b11+b12)
        p7 = strassen(subtract(a12, a22), add(b21, b22))  # (a12-a22) * (b21+b22)

        # calculating c11, c12, c21 and c22:
        c11 = subtract(add(add(p1, p4), p7), p5)  # c11 = p1 + p4 - p5 + p7
        c12 = add(p3, p5)                         # c12 = p3 + p5
        c21 = add(p2, p4)                         # c21 = p2 + p4
        c22 = subtract(add(add(p1, p3), p6), p2)  # c22 = p1 + p3 - p2 + p6

        # Grouping the results obtained in a single matrix: glue the left and
        # right quadrants row by row, then stack the upper half on the lower.
        upper = [left + right for left, right in zip(c11, c12)]
        lower = [left + right for left, right in zip(c21, c22)]
        return upper + lower
    
    def strassen(A, B):
        """Multiply two square matrices of equal size with Strassen's algorithm.

        A and B are n x n lists of rows. Both are copied into m x m matrices
        padded with zeros, where m is the smallest power of two not below n,
        handed to strassenR, and the top-left n x n part of that product is
        returned as a new list of rows. The inputs are not modified.

        Raises AssertionError when an argument is not a list or when the
        sizes of A and B (judged by their first rows) are not all equal.
        """
        assert isinstance(A, list) and isinstance(B, list)
        assert len(A) == len(A[0]) == len(B) == len(B[0])

        n = len(A)
        # Smallest power of two >= n using integer arithmetic only; going
        # through ceil(log(n, 2)) depends on float rounding and can overshoot
        # when n is already an exact power of two.
        m = 1 << (n - 1).bit_length()

        extra = m - n
        zero_tail = [0] * extra
        APrep = [list(row[:n]) + zero_tail for row in A]
        BPrep = [list(row[:n]) + zero_tail for row in B]
        # Each padding row must be its own list so rows never alias.
        APrep.extend([0] * m for _ in range(extra))
        BPrep.extend([0] * m for _ in range(extra))

        CPrep = strassenR(APrep, BPrep)

        # Strip the padding again: keep only the top-left n x n block.
        return [list(row[:n]) for row in CPrep[:n]]
    
    if __name__ == "__main__":
        # Command line: -i names the file that holds both input matrices.
        cli = OptionParser()
        cli.add_option(
            "-i",
            dest="filename",
            default="2000.in",
            help="input file with two matrices",
            metavar="FILE",
        )
        opts, _positional = cli.parse_args()

        # Load both operands, multiply them and print the product.
        left, right = read(opts.filename)
        printMatrix(strassen(left, right))