
Multiply 2x2 Matrix in Python Recursively? No nested loops or numpy


I'm having a difficult time finding a very simple program that multiplies two 2x2 matrices recursively. Can anyone help me out? I just need to multiply X and Y without using numpy or nested loops.

X = [[1, 2],
     [2, 3]]

Y = [[2, 3],
     [3, 4]]

FWIW - Here's my naïve method :)

X = [[1, 2],
     [2, 3]]

Y = [[2, 3],
     [3, 4]]

result = [[0, 0],
          [0, 0]]

# result[i][j] is the dot product of row i of X and column j of Y
for i in range(len(X)):
    for j in range(len(Y[0])):
        for k in range(len(Y)):
            result[i][j] += X[i][k] * Y[k][j]

for r in result:
    print(r)

# ANS = [8, 11], [13, 18]
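For a 2x2 product each entry of the result is just the dot product of a row of X with a column of Y, so writing the four entries out by hand is a quick way to check the answer:

```latex
\begin{pmatrix} 1 & 2 \\ 2 & 3 \end{pmatrix}
\begin{pmatrix} 2 & 3 \\ 3 & 4 \end{pmatrix}
=
\begin{pmatrix}
1\cdot 2 + 2\cdot 3 & 1\cdot 3 + 2\cdot 4 \\
2\cdot 2 + 3\cdot 3 & 2\cdot 3 + 3\cdot 4
\end{pmatrix}
=
\begin{pmatrix} 8 & 11 \\ 13 & 18 \end{pmatrix}
```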

Edit: per the comments, here's a recursive Strassen's version (this one does use numpy):

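import numpy as np


def split(matrix):
    # Cut a square matrix into its four quadrants:
    # top-left, top-right, bottom-left, bottom-right.
    row, col = matrix.shape
    row2, col2 = row // 2, col // 2
    return matrix[:row2, :col2], matrix[:row2, col2:], matrix[row2:, :col2], matrix[row2:, col2:]


def strassen_recur(x, y):
    # Works for square arrays whose size is a power of 2 (e.g. 2x2, 4x4).
    if len(x) == 1:
        return x * y

    a, b, c, d = split(x)
    e, f, g, h = split(y)

    # Seven recursive products instead of the eight a plain block split needs.
    p1 = strassen_recur(a, f - h)
    p2 = strassen_recur(a + b, h)
    p3 = strassen_recur(c + d, e)
    p4 = strassen_recur(d, g - e)
    p5 = strassen_recur(a + d, e + h)
    p6 = strassen_recur(b - d, g + h)
    p7 = strassen_recur(a - c, e + f)

    # Combine the products into the four quadrants of the result.
    c1 = p5 + p4 - p2 + p6
    c2 = p1 + p2
    c3 = p3 + p4
    c4 = p1 + p5 - p3 - p7

    return np.vstack((np.hstack((c1, c2)), np.hstack((c3, c4))))


x = np.array([[1, 2],
              [2, 3]])
y = np.array([[2, 3],
              [3, 4]])
print(strassen_recur(x, y))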

I also have a naive Strassen method written. But as I mentioned, I was just hoping somebody had something quick to show me so I didn't have to spend a lot of time trying to figure it out. It's all good.
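Since the question asked for no numpy, here's a rough sketch of the same seven-product Strassen recursion on plain Python lists. The helper names (mat_add, mat_sub, quadrants, assemble, strassen_lists) are my own, not from any library, and like the numpy version it assumes square matrices whose size is a power of 2.

```python
import operator


def mat_add(A, B):
    # Element-wise A + B for two equally sized list-of-lists matrices.
    return [list(map(operator.add, ra, rb)) for ra, rb in zip(A, B)]


def mat_sub(A, B):
    # Element-wise A - B.
    return [list(map(operator.sub, ra, rb)) for ra, rb in zip(A, B)]


def quadrants(M):
    # Split a square matrix of even size into its four quadrants.
    n = len(M) // 2
    top, bottom = M[:n], M[n:]
    return ([r[:n] for r in top], [r[n:] for r in top],
            [r[:n] for r in bottom], [r[n:] for r in bottom])


def assemble(c1, c2, c3, c4):
    # Glue four quadrants back together into one matrix.
    return ([left + right for left, right in zip(c1, c2)] +
            [left + right for left, right in zip(c3, c4)])


def strassen_lists(x, y):
    # Assumes square matrices whose size is a power of 2 (2x2 qualifies).
    if len(x) == 1:
        return [[x[0][0] * y[0][0]]]

    a, b, c, d = quadrants(x)
    e, f, g, h = quadrants(y)

    # Same seven products as the numpy version.
    p1 = strassen_lists(a, mat_sub(f, h))
    p2 = strassen_lists(mat_add(a, b), h)
    p3 = strassen_lists(mat_add(c, d), e)
    p4 = strassen_lists(d, mat_sub(g, e))
    p5 = strassen_lists(mat_add(a, d), mat_add(e, h))
    p6 = strassen_lists(mat_sub(b, d), mat_add(g, h))
    p7 = strassen_lists(mat_sub(a, c), mat_add(e, f))

    # Combine the products into the four quadrants of the result.
    c1 = mat_add(mat_sub(mat_add(p5, p4), p2), p6)
    c2 = mat_add(p1, p2)
    c3 = mat_add(p3, p4)
    c4 = mat_sub(mat_sub(mat_add(p1, p5), p3), p7)
    return assemble(c1, c2, c3, c4)


X = [[1, 2],
     [2, 3]]

Y = [[2, 3],
     [3, 4]]

print(strassen_lists(X, Y))  # [[8, 11], [13, 18]]
```

For a 2x2 input the recursion bottoms out after one level, so this is really just the seven-multiplication formula for 2x2 matrices; it only pays off (in theory) for much larger sizes.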


Solution

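  • Here's the answer in case anyone is ever looking for a recursive way to multiply two 2x2 matrices... or 3x3, 4x4, whatever. The recursion takes the place of the loops (i steps through the rows of X, j through the columns of Y, and k through the shared dimension), so for other sizes you only need to change X, Y and the row/column counts; col1 has to equal row2. Very large matrices will eventually hit Python's recursion limit. It's not the prettiest, but it works. Maybe someone out there can make it even better?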

    
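    X = [[1, 2],
         [2, 3]]
    
    Y = [[2, 3],
         [3, 4]]
    
    # Recursion state: i walks the rows of X, j the columns of Y,
    # and k the shared dimension (columns of X = rows of Y).
    i = 0
    j = 0
    k = 0
    
    def multiplyMatrixRec(row1, col1, X, row2, col2, Y, result):
        global i, j, k
    
        # Every row has been handled: the product is complete.
        if i >= row1:
            return
    
        if j < col2:
            if k < col1:
                # Add one term of the dot product for result[i][j].
                result[i][j] += X[i][k] * Y[k][j]
                k += 1
            else:
                # result[i][j] is finished; move on to the next column.
                k = 0
                j += 1
        else:
            # Row i is finished; move on to the next row.
            j = 0
            i += 1
        multiplyMatrixRec(row1, col1, X, row2, col2, Y, result)
    
    
    def multiplyMatrix(row1, col1, X, row2, col2, Y):
        global i, j, k
        if col1 != row2:
            print("Not possible: col1 must equal row2")
            return
        result = [[0] * col2 for _ in range(row1)]
        i = j = k = 0
        multiplyMatrixRec(row1, col1, X, row2, col2, Y, result)
        for row in result:
            print(*row)
    
    row1 = 2
    col1 = 2
    row2 = 2
    col2 = 2
    multiplyMatrix(row1, col1, X, row2, col2, Y)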
    

    Output:

    8 11
    13 18

    Cheers!
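If you'd rather not keep global counters at all, here's a sketch that uses nothing but recursion on lists: a recursive dot product, a recursive column grab, and recursion over rows and columns. The function names (dot, column, result_row, multiply) are just my own picks, and it assumes Y has at least one row and that X has as many columns as Y has rows (it doesn't check).

```python
def dot(row, col):
    # Recursive dot product: first pair's product plus the dot product of the rest.
    if not row:
        return 0
    return row[0] * col[0] + dot(row[1:], col[1:])


def column(M, j):
    # Recursively collect column j of M as a list.
    if not M:
        return []
    return [M[0][j]] + column(M[1:], j)


def result_row(row, Y, j=0):
    # Recursively build one row of the product, one column of Y at a time.
    if j == len(Y[0]):
        return []
    return [dot(row, column(Y, j))] + result_row(row, Y, j + 1)


def multiply(X, Y):
    # Recursively build the product, one row of X at a time.
    if not X:
        return []
    return [result_row(X[0], Y)] + multiply(X[1:], Y)


X = [[1, 2],
     [2, 3]]

Y = [[2, 3],
     [3, 4]]

for row in multiply(X, Y):
    print(*row)
```

This prints the same "8 11" / "13 18" as above, and it works for any compatible sizes without touching the code. The slicing copies lists on every call, so it's fine for small matrices but not meant to be fast.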