Search code examples
calgorithmstrassen

Strassen's multiplication in C


Please have a look in the following code:

#include<stdio.h>
#include<stdlib.h>

int **divide(int **Matrix,int n,int position)
{
    int i,j;
    int **Partition=malloc(sizeof(*Partition)*n);
    for(i=0;i<n;i++)
    {
        Partition[i]=calloc(n,sizeof(*Partition[i]));
    }
    if(position==1)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i][j];
            }
        }
    }
    else if(position==2)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i][j+n/2];
            }
        }
    }
    else if(position==3)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i+n/2][j];
            }
        }
    }
    else if(position==4)
    {
        for(i=0;i<n/2;i++)
        {
            for(j=0;j<n/2;j++)
            {
                Partition[i][j]=Matrix[i+n/2][j+n/2];
            }
        }
    }
    return Partition;
}


int **allocate(int n)
{
    int **newmatrix=malloc(sizeof(*newmatrix)*n);
    for(int i=0;i<n;i++)
    {
        newmatrix[i]=calloc(n, sizeof(*newmatrix[i]));
    }
    return newmatrix;
}
void mfree(int **matrix,int n) {
    for (int i=0;i<n;i++) {
        free(matrix[i]);
    }
    free(matrix);
}
int **add(int **a,int **b,int n)
{
    int **c=allocate(n);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            c[i][j]=a[i][j]+b[i][j];
        }
    }
    return c;
}
int **subtract(int **a,int **b,int n)
{
    int **c=allocate(n);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            c[i][j]=a[i][j]-b[i][j];
        }
    }
    return c;
}
void print(int **Matrix,int n)
{
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            printf("%d ",Matrix[i][j]);
        }
        printf("\n");
    }
}

int **Strassens(int **A,int **B,int n)
{
    int **C=allocate(n);
    if(n==1)
    {
        C[0][0]=A[0][0]*B[0][0];
    }
    else
    {   //Allocate the submatrices
        int **a11=allocate(n/2);
        int **a12=allocate(n/2);
        int **a21=allocate(n/2);
        int **a22=allocate(n/2);

        int **b11=allocate(n/2);
        int **b12=allocate(n/2);
        int **b21=allocate(n/2);
        int **b22=allocate(n/2);


        a11=divide(A,n,1);
        a12=divide(A,n,2);
        a21=divide(A,n,3);
        a22=divide(A,n,4);

        b11=divide(B,n,1);
        b12=divide(B,n,2);
        b21=divide(B,n,3);
        b22=divide(B,n,4);

        
        int **s1=subtract(b12,b22,n/2);
        int **s2=add(a11,a12,n/2);
        int **s3=add(a21,a22,n/2);
        int **s4=subtract(b21,b11,n/2);
        int **s5=add(a11,a22,n/2);
        int **s6=add(b11,b22,n/2);
        int **s7=subtract(a12,a22,n/2);
        int **s8=add(b21,b22,n/2);
        int **s9=subtract(a11,a21,n/2);
        int **s10=add(b11,a12,n/2);

        int **p1=Strassens(a11,s1,n/2);
        int **p2=Strassens(s2,b22,n/2);
        int **p3=Strassens(s3,b11,n/2);
        int **p4=Strassens(a22,s4,n/2);
        int **p5=Strassens(s5,s6,n/2);
        int **p6=Strassens(s7,s8,n/2);
        int **p7=Strassens(s9,s10,n/2);


        int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
        int **c12=add(p1,p2,n/2);
        int **c21=add(p3,p4,n/2);
        int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);

        for(int i=0;i<n/2;i++)
        {
            for(int j=0;j<n/2;j++)
            {
                C[i][j]=c11[i][j];
                C[i][j+n/2]=c12[i][j];
                C[i+n/2][j]=c21[i][j];
                C[i+n/2][j+n/2]=c22[i][j];
            }
        }
    }
    return C;
}

int main()
{
    int n=8;  //Dimension of the square matrix,  n*n;
    int **A=allocate(n);
    int **B=allocate(n);
    for(int i=0;i<n;i++)
    {
        for(int j=0;j<n;j++)
        {
            A[i][j]=j+1;
            B[i][j]=j+1;
        }
    }
    printf("Matrix A:\n");
    print(A,n);
    printf("Matrix B: \n");
    print(B,n);
    printf("\n...Performing Multiplication with Strassen's...\nMatrix A*B:\n");
    int **C = Strassens(A,B,n);
    print(C,n);
    mfree(C,n);
}

I know it's a very silly question to ask, there's some problem with the math. But I'm not able to get where am I going wrong. The problem is that, when I multiply two matrices with equal values I get the desired result, but that doesn't applies for matrices with different values. For example, have a look at the outputs:

Matrix A:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8

...Performing Multiplication with Strassen's...
Matrix A*B:
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288
36 72 108 144 180 216 252 288

and

Matrix A:
1 2 3 4 5 6 7 8
2 3 4 5 6 7 8 9
3 4 5 6 7 8 9 10
4 5 6 7 8 9 10 11
5 6 7 8 9 10 11 12
6 7 8 9 10 11 12 13
7 8 9 10 11 12 13 14
8 9 10 11 12 13 14 15
Matrix B:
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8
1 2 3 4 5 6 7 8

...Performing Multiplication with Strassen's...
Matrix A*B:
316 424 484 528 460 440 372 288
300 398 452 426 412 366 308 154
268 360 414 446 348 312 246 134
252 254 382 424 300 126 182 112
156 232 260 272 404 352 252 136
140 150 228 34 356 334 188 138
108 168 70 54 292 224 246 118
92 -122 38 24 244 222 182 104

Solution

  • :_) Sorry for this. There was a slight math error in this part:

    int **c11=subtract(add(p5,p4,n/2),add(p2,p6,n/2),n/2);
    int **c12=add(p1,p2,n/2);
    int **c21=add(p3,p4,n/2);
    int **c22=subtract(add(p5,p1,n/2),subtract(p3,p7,n/2),n/2);
    

    Replacing c11 and c22 with

    int **c11=subtract(add(add(p5,p4,n/2),p6,n/2),p2,n/2);
    ...
    int **c22=subtract(subtract(add(p5,p1,n/2),p3,n/2),p7,n/2);
    

    corrects the math error.