Search code examples
algorithmsubroutinestrassen

strassen matrix multiplication


This is exercise 4.2-6 from *Introduction to Algorithms* (CLRS). It is stated as follows:

How quickly can you multiply a $kn \times n$ matrix by an $n \times kn$ matrix, using Strassen's algorithm as a subroutine?

I'm thinking of expanding both matrices to $kn \times kn$ matrices, so that I can apply Strassen's algorithm to them directly. But that gives a running time of $\Theta((kn)^{\lg 7})$.

Does anybody have a better solution? Happy new year to everyone.


Solution

  • Here is another vector-based implementation of Strassen's algorithm; it compares the running times of the naive approach and Strassen's approach:

    enter code here:
    #include <cstdio>
    #include <iostream>
    #include <cstdlib>
    #include <ctime>
    #include <cassert>
    #include <vector>
    #include <ctime>
    using namespace std;
    // Copies one n x n quadrant of u into the first n rows of m.
    // P picks the quadrant: 1 = top-left, 2 = top-right, 3 = bottom-left,
    // 4 = bottom-right. For any other P each of those rows of m ends up empty.
    void fun(vector<vector<int> >& u , vector<vector<int> >&m , int P , int n)
    {
        const bool validQuadrant = (P >= 1 && P <= 4);
        // Quadrants 3 and 4 live in the lower half, 2 and 4 in the right half.
        const int rowShift = (P == 3 || P == 4) ? n : 0;
        const int colShift = (P == 2 || P == 4) ? n : 0;

        for(int r = 0 ; r < n ; ++r)
        {
            vector<int> rowCopy;
            if(validQuadrant)
            {
                const vector<int>& srcRow = u[r + rowShift];
                for(int c = 0 ; c < n ; ++c)
                {
                    rowCopy.push_back(srcRow[c + colShift]);
                }
            }
            // Replaces the whole row, so its length becomes n (or 0 for a bad P).
            m[r] = rowCopy;
        }
    }
    // Schoolbook matrix product: multiplies the leading n x n parts of u and v
    // and stores the result in the leading n x n part of z.
    void normalmul(int n , vector< vector<int> >& u   , vector< vector<int> >& v  ,     vector< vector<int> >& z )
    {
        for(int row = 0 ; row < n ; ++row)
        {
            vector<int>& lhsRow = u[row];
            vector<int>& outRow = z[row];
            for(int col = 0 ; col < n ; ++col)
            {
                // Zero the destination first, then accumulate the dot product in place.
                int& cell = outRow[col];
                cell = 0;
                for(int mid = 0 ; mid < n ; ++mid)
                {
                    cell += (lhsRow[mid] * v[mid][col]);
                }
            }
        }
    }
    
    // out = a + b, element-wise over the leading h x h entries.
    static void strassen_add(int h , const vector< vector<int> >& a , const vector< vector<int> >& b , vector< vector<int> >& out)
    {
        for(int i = 0 ; i < h ; i++)
        {
            for(int j = 0 ; j < h ; j++)
            {
                out[i][j] = a[i][j] + b[i][j];
            }
        }
    }

    // out = a - b, element-wise over the leading h x h entries.
    static void strassen_sub(int h , const vector< vector<int> >& a , const vector< vector<int> >& b , vector< vector<int> >& out)
    {
        for(int i = 0 ; i < h ; i++)
        {
            for(int j = 0 ; j < h ; j++)
            {
                out[i][j] = a[i][j] - b[i][j];
            }
        }
    }

    // Multiplies the leading n x n parts of u and v into the leading n x n part
    // of z using Strassen's seven-product recursion.
    //
    //   n : matrix dimension
    //   u : left operand, at least n x n (only read)
    //   v : right operand, at least n x n (only read)
    //   z : result, must already be at least n x n
    //
    // Sizes of 32 or less, and odd sizes, are handed to normalmul(). Testing only
    // for n == 32 is not enough: a size below 32 would keep halving down to 0 and
    // recurse without end, and an odd size cannot be cut into four equal
    // quadrants, so its last row and column would never be computed.
    void strassen(int n , vector< vector<int> >& u   , vector< vector<int> >& v  , vector< vector<int> >& z)
    {
        if(n <= 32 || n % 2 != 0)
        {
            normalmul(n,u,v,z);
            return;
        }

        const int half = n>>1;

        // Quadrants of the operands: u = [AA BB; CC DD], v = [EE FF; GG HH].
        vector<vector<int> >AA(half , vector<int>(half));
        vector<vector<int> >BB(half , vector<int>(half));
        vector<vector<int> >CC(half , vector<int>(half));
        vector<vector<int> >DD(half , vector<int>(half));
        vector<vector<int> >EE(half , vector<int>(half));
        vector<vector<int> >FF(half , vector<int>(half));
        vector<vector<int> >GG(half , vector<int>(half));
        vector<vector<int> >HH(half , vector<int>(half));
        fun(u,AA,1,half);
        fun(u,BB,2,half);
        fun(u,CC,3,half);
        fun(u,DD,4,half);
        fun(v,EE,1,half);
        fun(v,FF,2,half);
        fun(v,GG,3,half);
        fun(v,HH,4,half);

        // The seven half-size products, plus two scratch operands reused for
        // every sum/difference that feeds a product.
        vector<vector<int> >M1(half , vector<int>(half));
        vector<vector<int> >M2(half , vector<int>(half));
        vector<vector<int> >M3(half , vector<int>(half));
        vector<vector<int> >M4(half , vector<int>(half));
        vector<vector<int> >M5(half , vector<int>(half));
        vector<vector<int> >M6(half , vector<int>(half));
        vector<vector<int> >M7(half , vector<int>(half));
        vector<vector<int> >T1(half , vector<int>(half));
        vector<vector<int> >T2(half , vector<int>(half));

        // M1 = (AA + DD) * (EE + HH)
        strassen_add(half,AA,DD,T1);
        strassen_add(half,EE,HH,T2);
        strassen(half,T1,T2,M1);

        // M2 = (CC + DD) * EE
        strassen_add(half,CC,DD,T1);
        strassen(half,T1,EE,M2);

        // M3 = AA * (FF - HH)
        strassen_sub(half,FF,HH,T2);
        strassen(half,AA,T2,M3);

        // M4 = DD * (GG - EE)
        strassen_sub(half,GG,EE,T2);
        strassen(half,DD,T2,M4);

        // M5 = (AA + BB) * HH
        strassen_add(half,AA,BB,T1);
        strassen(half,T1,HH,M5);

        // M6 = (CC - AA) * (EE + FF)
        strassen_sub(half,CC,AA,T1);
        strassen_add(half,EE,FF,T2);
        strassen(half,T1,T2,M6);

        // M7 = (BB - DD) * (GG + HH)
        strassen_sub(half,BB,DD,T1);
        strassen_add(half,GG,HH,T2);
        strassen(half,T1,T2,M7);

        // Combine the products straight into the four quadrants of z.
        for(int i = 0 ; i < half ; i++)
        {
            for(int j = 0 ; j < half ; j++)
            {
                z[i][j]           = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j] ;
                z[i][j+half]      = M3[i][j] + M5[i][j] ;
                z[i+half][j]      = M2[i][j] + M4[i][j] ;
                z[i+half][j+half] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j] ;
            }
        }
    }
    
    
    // Reads n*n integers from standard input into the leading n x n part of m,
    // row by row. Returns false as soon as a value cannot be read.
    static bool read_square_matrix(int n , vector< vector<int> >& m)
    {
        for(int i = 0 ; i < n ; i++)
        {
            for(int j = 0 ; j < n ; j++)
            {
                if(!(cin >> m[i][j]))
                {
                    return false;
                }
            }
        }
        return true;
    }

    // Benchmark driver.
    //
    // Redirects standard input to "input_file.txt" and reads a test-case count,
    // then for each case a size n followed by the n*n entries of the first
    // matrix and the n*n entries of the second. For every case it prints n and
    // the processor time (clock()) spent in normalmul() and in strassen().
    //
    // Returns 0 on success, 1 if the file cannot be opened or the input is
    // missing, malformed or has a negative size.
    int main()
    {
        if(freopen("input_file.txt","r",stdin) == NULL)
        {
            cerr << "cannot open input_file.txt" << endl ;
            return 1;
        }

        int t = 0;
        if(!(cin >> t))
        {
            cerr << "cannot read the number of test cases" << endl ;
            return 1;
        }

        // `> 0` so that a negative count ends the run instead of looping on.
        while(t-- > 0)
        {
            int n = 0;
            if(!(cin >> n) || n < 0)
            {
                cerr << "cannot read a valid matrix size" << endl ;
                return 1;
            }
            cout <<  "value of n " << n  << endl ;

            vector< vector<int> >u(n,vector<int>(n));
            vector< vector<int> >v(n,vector<int>(n));
            vector< vector<int> >z(n,vector<int>(n));
            vector< vector<int> >zz(n,vector<int>(n));
            if(!read_square_matrix(n,u) || !read_square_matrix(n,v))
            {
                cerr << "cannot read the matrix entries" << endl ;
                return 1;
            }

            clock_t start , end ;

            //USING NAIVE APPROACH

            // Print the label before starting the clock so that only the
            // multiplication itself is timed.
            cout<<"Traditional Algorithm Running Time : ";
            start = clock();
            normalmul(n,u,v,z);
            end = clock() ;
            cout<<static_cast<double>(end-start)/CLOCKS_PER_SEC<<" seconds"<<endl ;


            /*cout << "ANSWER OF MULTIPLICATION BY NAIVE APPROACH" << endl ;
            for(int i = 0 ; i < n ; i++)
            {
                for(int j = 0 ; j  < n ; j++)
                {
                    cout << z[i][j] << " ";
                }
                cout << endl ;
            }*/


            //USING STRASSENS ALGORITHM

            start = clock() ;
            strassen(n,u,v,zz);
            end = clock();
            cout<<"Strassen Algorithm Running Time : ";
            cout<<static_cast<double>(end-start)/CLOCKS_PER_SEC<<" seconds"<<endl ;

            /*cout << "ANSWER BY STRASSENS ALGORITHM " << endl ;
            for(int i = 0 ; i < n ; i++)
            {
                for(int j = 0 ; j  < n ; j++)
                {
                    cout << zz[i][j] << " ";
                }
                cout << endl ;
            }*/
        }
        /*  IPG_2011006   Abhishek Yadav */
        return 0;
    }