Search code examples
algorithm · matrix-multiplication · strassen

Trouble With Implementing Strassen's Algorithm for Matrix Multiplication


I've been trying to implement Strassen's algorithm for matrix multiplication for the past couple of hours and have had trouble getting the correct product. I think one of my helper functions (helpSub, createProd, helpProduct) may be the issue, or the structure of my strass2 function (order of commands, etc.). Any tips would be welcome because I'm totally stumped. I've been using two 4 x 4 matrices as test matrices. I've tried tons of variations of p1-p7 and c1-c4 that I've seen on the internet, but none seem to work. Below is the class I've created.

 /** @author williamnewman */

/**
 * First attempt at Strassen's matrix multiplication, exactly as posted in the question.
 * This version does not produce the correct product; the NOTE(review) comments below
 * mark the places where it departs from the standard Strassen formulas. The code itself
 * is left untouched.
 */
public class strassen2 {

//Main Strassen multiplication function
//BASE CASE:
// Recursively multiplies x by y. A 1x1 input is multiplied directly; otherwise both
// operands are split into four quadrants, seven recursive products p1..p7 are formed,
// and they are recombined into the result quadrants c1..c4.
int [][] strass2(int[][] x, int[][]y){
    if(x.length == 1 && y.length == 1){
        System.out.println("Donezo");
        int [][] nu = new int[1][1];
        nu[0][0] = x[0][0] * y[0][0];
        return nu;

    }
    else{
   int[][] a,b,c,d,e,f,g,h;
   // Side length of one quadrant; integer division, so an odd length is truncated.
   int dim = x.length/2;

//Dividing two matrices into 8 sub matrices
// x = [a b; c d], y = [e f; g h]. Each quadrant is dumped with C(...) as a debug trace.
  System.out.println("A<B<C");
   a = helpSub(0,0,x);
   C(a);
   b = helpSub(0,dim,x);

   C(b);
   c = helpSub(dim,0,x);
   C(c);
   d = helpSub(dim,dim,x);
   C(d);
   e = helpSub(0,0,y);
   C(e);
   f = helpSub(0,dim,y);
   C(f);
   g = helpSub(dim,0,y);
   C(g);
   h = helpSub(dim,dim,y);
   C(h);

   int[][] p1,p2,p3,p4,p5,p6,p7;


//Creating p1-p7
   // NOTE(review): the lone '/' on the next line is a syntax error as posted; it looks
   // like the remnant of a stripped comment delimiter.
   /
   p1 = strass2(a,subtract(f,h));
   // NOTE(review): operands reversed. Strassen's p2 is (a+b)*h, and matrix
   // multiplication is not commutative, so h*(a+b) is a different matrix.
   p2 = strass2(h, add(a,b));
   // NOTE(review): operands reversed here too; p3 should be (c+d)*e.
   p3 = strass2(e,add(c,d));
   p4 = strass2(d,subtract(g,e));
   p5 = strass2(add(a,d),add(e,h));
   p6 = strass2(subtract(b,d),add(g,h));
   p7 = strass2(subtract(a,c),add(e,f));
   int [][] prod;
   int [][] c1,c2,c3,c4;

//Creating c1-c4
   // NOTE(review): this evaluates to p5 + p6 - p4 + p2, but the top-left quadrant
   // should be p5 + p4 - p2 + p6 (the signs of p4 and p2 are swapped).
   c1 = subtract(add(p6,p5),subtract(p4,p2));
   c2 = add(p1,p2);
   c3 = add(p3,p4);
   // NOTE(review): this evaluates to p1 + p5 - p3 + p7, but the bottom-right quadrant
   // should be p1 + p5 - p3 - p7.
   c4 = subtract(add(p1,p5),subtract(p3,p7));
   // Debug trace: each label is printed AFTER the matrix it names.
   C(c1);
   System.out.println("C1::");
   C(c2);
   System.out.println("C2::");
   C(c3);
   System.out.println("C3::");
   C(c4);
   System.out.println("C4::");
//CREATES PRODUCT MATRIX
   prod = createProd(c1,c2,c3,c4);
   return prod;

    }




}

//Creates product matrix from c1-c4
// The result has twice the side length of c1: c1 goes top-left, c2 top-right,
// c3 bottom-left and c4 bottom-right. The assembled matrix is also printed.
int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4){
    int[][] product = new int[c1.length*2][c1.length*2];
    int mid = c1.length;
    int fin = c1.length * 2;
    helpProduct(0,0,mid,mid,product,c1);
    helpProduct(0,mid,mid,fin,product,c2);
    helpProduct(mid,0,fin,mid,product,c3);
    helpProduct(mid,mid,fin,fin,product,c4);

     System.out.println();
    System.out.println("PRODUCT::!:");
    C(product);
    return product;



}

    //Helper function to create larger matrix from submatrices
    // Copies a1 into product so that a1[0][0] lands at product[x][y]; rows x..z1-1 and
    // columns y..z2-1 of product are overwritten (both upper bounds exclusive).
void helpProduct(int x, int y, int z1, int z2,int[][] product, int[][] a1){
    int indR = 0;
    int indC = 0;
    for(int i = x; i < z1; i++){
        // Restart at the first column of a1 for every new row.
        indC = 0;
        for(int j = y; j < z2; j++){
            product[i][j] = a1[indR][indC];
            indC++;
        }
        indR++;
    }
}


    // Returns a copy of the square block of mat whose top-left corner is (row x, column y)
    // and whose side length is mat.length/2.
    int[][] helpSub(int x, int y, int[][] mat){
    int[][] sub = new int[mat.length/2][mat.length/2];
    for(int i1 = 0, i2=x; i1 < (mat.length/2); i1++, i2++)
    for(int j1 = 0, j2=y; j1<(mat.length/2); j1++, j2++)
    {
            sub[i1][j1] = mat[i2][j2];
                           // System.out.println(sub[i1][j1]);
    }
    return sub;
}



//Normal Matrix Multiplication Function
// Delegates to the MM class, which is not part of this listing, and returns its
// product field. strass2 does not call this method in this version.
int[][] multiply(int[][]a,int[][]b){
    MM nu = new MM(a,b);
    return nu.product;
}

    //Adds one matrix to the next
    // Element-wise sum a + b returned as a new array; b is read at every index of a.
int[][] add(int[][]a, int[][]b){
    int [][] nu = new int[a.length][a[0].length];
    for(int i = 0; i < a.length; i++){
        for(int j = 0; j < a[i].length;j++){
            nu[i][j] = a[i][j] + b[i][j];
        }
    }
    return nu;
}

//Subtracts second matrix from the first
// Element-wise difference a - b returned as a new array.
int[][] subtract(int[][] a, int[][] b){
    // NOTE(review): allocated as a.length x a.length, unlike add(), which uses
    // a[0].length for the column count; this only matches when a is square.
    int [][] sub = new int[a.length][a.length];
    //System.out.println("made it");
    for(int i = 0; i < a.length; i++){
        for(int j = 0; j < a[i].length;j++){
            sub[i][j] = a[i][j] - b[i][j];
        }
    }
    return sub;
}
//Prints the matrix
// One output line per outer index, values separated by single spaces.
 void C(int[][] product){
    for(int i = 0; i <product.length; i++){
        for(int j = 0; j < product[i].length; j++){
            // NOTE(review): indexed [j][i], so the matrix is printed transposed.
            System.out.print(product[j][i]  + " ");

        }
        System.out.println();
    }
}
}

If anything is confusing let me know and I'll update the question!

Here is the main function:

      /**
       * Demo driver: builds two fixed 4x4 matrices, runs them through the MM helper
       * (presumably the classical reference product; MM is not shown here) and then
       * through {@code strassen2.strass2}, printing both results with a separator between.
       *
       * @param args command-line arguments (unused)
       */
      public static void main(String[] args) {
          final int[][] left = {
              {1, 2, 3, 4},
              {4, 3, 2, 1},
              {1, 2, 3, 4},
              {4, 3, 2, 1}
          };
          final int[][] right = {
              {3, 4, 5, 6},
              {3, 4, 5, 6},
              {5, 4, 3, 2},
              {5, 4, 3, 2}
          };

          // Reference run first, so its output appears above the separator.
          MM reference = new MM(left, right);
          reference.C();

          System.out.println("----");

          // Strassen run on the same operands.
          strassen2 solver = new strassen2();
          int[][] strassenProduct = solver.strass2(left, right);
          solver.C(strassenProduct);
      }

}

Here are the results so far (the expected result is the first 4x4 matrix shown and the actual result is the last 4x4 matrix shown):

EXPECTED:

44 40 36 32 
36 40 44 48 
44 40 36 32 
36 40 44 48 
----


ACTUAL::
70 78 50 42 
86 86 34 34 
30 38 30 38 
38 54 38 54 

I'm pretty sure my helpSub() function works, because it produced the correct a-h. However, there might be a problem with the parameters I use in the recursive strass2 calls. I'm sorry if this isn't specific enough; I'm just a bit burnt out on it and was curious whether anyone saw any glaring issues.


Solution

  • Sorry for being vague, but it seems I have solved the problem. I used the formulas from this website for p1-p7 and c1-c4: [Formulas for Strassen's Matrix Multiplication][1].

    [1]: http://www.stoimen.com/blog/2012/11/26/computer-algorithms-strassens-matrix-multiplication/

    After implementing those formulas the product matrices were nearly correct, but 4 of the values were off. I then changed the base case so that it applies when the lengths of x and y are at most two, and that seemed to correct the 4 values that were off. For those who are curious, here is my modified code for the strassen2 class.

    /*
     * To change this license header, choose License Headers in Project Properties.
     * To change this template file, choose Tools | Templates
     * and open the template in the editor.
     */
    package pkg2a;
    
    /**
     * Multiplies integer matrices with Strassen's divide-and-conquer algorithm.
     *
     * <p>Each level splits both operands into four quadrants, forms seven recursive
     * products (a naive block multiplication needs eight) and recombines them into the
     * quadrants of the result. Whenever a level cannot be halved evenly (odd side
     * length, or operands that are not square matrices of the same size) the
     * computation falls back to the classical triple-loop product, so every pair of
     * dimension-compatible matrices yields the correct result.
     *
     * <p>All arithmetic is plain {@code int} arithmetic and wraps silently on overflow.
     * No method modifies its arguments, except {@link #helpProduct}, which writes into
     * the destination matrix it is given.
     *
     * @author williamnewman
     */
    public class strassen2 {

        /** Side length at or below which the classical product is used directly. */
        private static final int BASE_CASE_SIZE = 2;

        /**
         * Computes the matrix product {@code x * y}.
         *
         * @param x left operand
         * @param y right operand
         * @return a newly allocated matrix holding {@code x * y}
         * @throws IllegalArgumentException if the operands' dimensions are incompatible
         */
        int[][] strass2(int[][] x, int[][] y) {
            int n = x.length;
            // Strassen's split needs two square operands of the same even size. Anything
            // else (including the small base case) goes to the classical product instead
            // of being silently truncated by the integer division below.
            if (n <= BASE_CASE_SIZE || n % 2 != 0 || !isSquare(x, n) || !isSquare(y, n)) {
                return multiply(x, y);
            }

            int dim = n / 2;

            // x = [a b; c d], y = [e f; g h]
            int[][] a = helpSub(0, 0, x);
            int[][] b = helpSub(0, dim, x);
            int[][] c = helpSub(dim, 0, x);
            int[][] d = helpSub(dim, dim, x);
            int[][] e = helpSub(0, 0, y);
            int[][] f = helpSub(0, dim, y);
            int[][] g = helpSub(dim, 0, y);
            int[][] h = helpSub(dim, dim, y);

            // The seven Strassen products. Operand order matters: matrix
            // multiplication is not commutative.
            int[][] p1 = strass2(a, subtract(f, h));
            int[][] p2 = strass2(add(a, b), h);
            int[][] p3 = strass2(add(c, d), e);
            int[][] p4 = strass2(d, subtract(g, e));
            int[][] p5 = strass2(add(a, d), add(e, h));
            int[][] p6 = strass2(subtract(b, d), add(g, h));
            int[][] p7 = strass2(subtract(a, c), add(e, f));

            int[][] c1 = add(subtract(add(p5, p4), p2), p6);      // p5 + p4 - p2 + p6
            int[][] c2 = add(p1, p2);                             // p1 + p2
            int[][] c3 = add(p3, p4);                             // p3 + p4
            int[][] c4 = subtract(subtract(add(p1, p5), p3), p7); // p1 + p5 - p3 - p7

            return createProd(c1, c2, c3, c4);
        }

        /** Returns true when {@code m} has exactly {@code n} rows of exactly {@code n} entries. */
        private static boolean isSquare(int[][] m, int n) {
            if (m.length != n) {
                return false;
            }
            for (int[] row : m) {
                if (row.length != n) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Assembles four equally sized square quadrants into one matrix.
         *
         * @param c1 top-left quadrant
         * @param c2 top-right quadrant
         * @param c3 bottom-left quadrant
         * @param c4 bottom-right quadrant
         * @return a new matrix with twice the side length of {@code c1}
         */
        int[][] createProd(int[][] c1, int[][] c2, int[][] c3, int[][] c4) {
            int mid = c1.length;
            int fin = mid * 2;
            int[][] product = new int[fin][fin];
            helpProduct(0, 0, mid, mid, product, c1);
            helpProduct(0, mid, mid, fin, product, c2);
            helpProduct(mid, 0, fin, mid, product, c3);
            helpProduct(mid, mid, fin, fin, product, c4);
            return product;
        }

        /**
         * Copies {@code a1} into a rectangular window of {@code product}.
         *
         * @param x first destination row (inclusive)
         * @param y first destination column (inclusive)
         * @param z1 last destination row (exclusive)
         * @param z2 last destination column (exclusive)
         * @param product destination matrix, modified in place
         * @param a1 source matrix; {@code a1[0][0]} lands at {@code product[x][y]}
         */
        void helpProduct(int x, int y, int z1, int z2, int[][] product, int[][] a1) {
            int width = z2 - y;
            if (width <= 0) {
                return; // empty window: nothing to copy
            }
            for (int row = x; row < z1; row++) {
                System.arraycopy(a1[row - x], 0, product[row], y, width);
            }
        }

        /**
         * Extracts one quadrant-sized square block from {@code mat}.
         *
         * @param x row of the block's top-left corner
         * @param y column of the block's top-left corner
         * @param mat source matrix
         * @return a new matrix of side {@code mat.length / 2} copied from {@code mat}
         */
        int[][] helpSub(int x, int y, int[][] mat) {
            int half = mat.length / 2;
            int[][] sub = new int[half][half];
            for (int row = 0; row < half; row++) {
                System.arraycopy(mat[x + row], y, sub[row], 0, half);
            }
            return sub;
        }

        /**
         * Classical (triple-loop) matrix product {@code a * b}.
         *
         * <p>Earlier revisions delegated to an external {@code MM} helper class that is
         * not part of this listing; the product is computed here directly so the class
         * is self-contained.
         *
         * @param a left operand, {@code r x k}
         * @param b right operand, {@code k x c}
         * @return a new {@code r x c} matrix
         * @throws IllegalArgumentException if a row of {@code a} does not have
         *     {@code b.length} entries, or if {@code b} is ragged
         */
        int[][] multiply(int[][] a, int[][] b) {
            int rows = a.length;
            int inner = b.length;
            int cols = inner == 0 ? 0 : b[0].length;

            for (int[] bRow : b) {
                if (bRow.length != cols) {
                    throw new IllegalArgumentException("right operand is not rectangular");
                }
            }

            int[][] out = new int[rows][cols];
            for (int i = 0; i < rows; i++) {
                if (a[i].length != inner) {
                    throw new IllegalArgumentException(
                            "row " + i + " of left operand has " + a[i].length
                                    + " entries, expected " + inner);
                }
                // i-k-j order walks both b and out row by row.
                for (int k = 0; k < inner; k++) {
                    int aik = a[i][k];
                    for (int j = 0; j < cols; j++) {
                        out[i][j] += aik * b[k][j];
                    }
                }
            }
            return out;
        }

        /**
         * Element-wise sum {@code a + b}.
         *
         * @param a first operand; determines the shape of the result
         * @param b second operand; must be at least as large as {@code a} in every row
         * @return a new matrix with the shape of {@code a}
         */
        int[][] add(int[][] a, int[][] b) {
            int[][] sum = new int[a.length][];
            for (int i = 0; i < a.length; i++) {
                sum[i] = new int[a[i].length];
                for (int j = 0; j < a[i].length; j++) {
                    sum[i][j] = a[i][j] + b[i][j];
                }
            }
            return sum;
        }

        /**
         * Element-wise difference {@code a - b}.
         *
         * @param a minuend; determines the shape of the result
         * @param b subtrahend; must be at least as large as {@code a} in every row
         * @return a new matrix with the shape of {@code a}
         */
        int[][] subtract(int[][] a, int[][] b) {
            // Rows are sized from a itself (not a.length x a.length) so that
            // non-square input works, matching add().
            int[][] diff = new int[a.length][];
            for (int i = 0; i < a.length; i++) {
                diff[i] = new int[a[i].length];
                for (int j = 0; j < a[i].length; j++) {
                    diff[i][j] = a[i][j] - b[i][j];
                }
            }
            return diff;
        }

        /**
         * Prints a matrix to standard output, one row per line, with each value
         * followed by a single space.
         *
         * @param product matrix to print
         */
        void C(int[][] product) {
            for (int[] row : product) {
                for (int value : row) {
                    System.out.print(value + " ");
                }
                System.out.println();
            }
        }
    }