Tags: c++, c++11, matrix, matrix-multiplication, strassen

"Splitting" a matrix in constant time


I am trying to implement Strassen's algorithm for matrix multiplication in C++, and I want to find a way to split two matrices into four parts each in constant time. Here is the current way I am doing so:

for(int i = 0; i < n; i++){
    for(int j = 0; j < n; j++){
        A11[i][j] = a[i][j];
        A12[i][j] = a[i][j+n];
        A21[i][j] = a[i+n][j];
        A22[i][j] = a[i+n][j+n];
        B11[i][j] = b[i][j];
        B12[i][j] = b[i][j+n];
        B21[i][j] = b[i+n][j];
        B22[i][j] = b[i+n][j+n];
    }
}
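For comparison, here is a self-contained version of the copy-based split above. It assumes 2n × 2n matrices stored as std::vector<std::vector<int>>, since the question does not show its matrix type; `Matrix` and `split_copy` are hypothetical names. Only `a` is shown, and `b` would be split the same way.

```cpp
#include <cstddef>
#include <vector>

// Assumption: a matrix is a vector of rows (hypothetical type alias).
using Matrix = std::vector<std::vector<int>>;

// Copies the four n x n quadrants of the 2n x 2n matrix `a`.
// Every element is copied once, so each call costs Theta(n^2) time and memory.
void split_copy(const Matrix& a, std::size_t n,
                Matrix& a11, Matrix& a12, Matrix& a21, Matrix& a22) {
    a11.assign(n, std::vector<int>(n));
    a12.assign(n, std::vector<int>(n));
    a21.assign(n, std::vector<int>(n));
    a22.assign(n, std::vector<int>(n));
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            a11[i][j] = a[i][j];          // top-left
            a12[i][j] = a[i][j + n];      // top-right
            a21[i][j] = a[i + n][j];      // bottom-left
            a22[i][j] = a[i + n][j + n];  // bottom-right
        }
    }
}
```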

This approach is obviously O(n^2) per call, and because it runs on every recursive call, the copying adds significantly to the total runtime.
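A more careful count, assuming the usual Strassen recursion (seven half-size products per call) and writing N for the full matrix dimension (the loop above uses n for the half-size): if a call on an N × N matrix copies c·N² elements, level k of the recursion has 7^k calls on matrices of size N/2^k, so the total copying is

```latex
\sum_{k=0}^{\log_2 N} 7^k \, c \left(\frac{N}{2^k}\right)^2
  = c\,N^2 \sum_{k=0}^{\log_2 N} \left(\frac{7}{4}\right)^k
  = \Theta\!\left(N^2 \left(\frac{7}{4}\right)^{\log_2 N}\right)
  = \Theta\!\left(N^{\log_2 7}\right)
```

So the copying has the same order as the matrix additions Strassen already performs at each level and does not change the overall Θ(N^(log₂ 7)) bound; avoiding it saves a constant factor of time and memory traffic.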

It seems that the way to do this in constant time is to create pointers to the four sub-matrices rather than copy the values, but I am having difficulty figuring out how to create those pointers. Any help would be appreciated.
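The pointer idea works as long as each sub-matrix also carries the row stride of its parent, because the rows of a quadrant are not contiguous in memory. A minimal sketch, assuming the matrix is stored row-major in one contiguous buffer (not as separately allocated rows); `Block` and `quadrant` are hypothetical names:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical non-owning handle to a square block inside a row-major buffer.
struct Block {
    double* data;        // address of the block's top-left element
    std::size_t size;    // the block is size x size
    std::size_t stride;  // distance between consecutive rows of the parent buffer

    // Element (i, j) of the block, read or written in place.
    double& at(std::size_t i, std::size_t j) const { return data[i * stride + j]; }
};

// Returns quadrant (r, c), with r, c in {0, 1}, of an even-sized block.
// Only pointer arithmetic happens here, so this is O(1) and copies nothing.
Block quadrant(const Block& m, std::size_t r, std::size_t c) {
    std::size_t h = m.size / 2;
    return Block{m.data + r * h * m.stride + c * h, h, m.stride};
}
```

Because every quadrant keeps the parent's stride, quadrants of quadrants work the same way, and writes through a quadrant land directly in the parent buffer. With separately allocated rows this trick does not apply directly; a quadrant would then need the row pointers plus a column offset.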


Solution

  • Don't think of matrices; think of matrix views.

    A matrix view holds a pointer to a buffer of T, a width, a height, an offset into the buffer, and a stride: the distance between the starts of consecutive rows (or columns, for column-major storage).

    We can start with an array view type.

    #include <algorithm> // std::min
    #include <cstddef>   // std::size_t
    
    template<class T>
    struct array_view {
      T* b = 0; T* e = 0;
      T* begin() const{ return b; }
      T* end() const{ return e; }
    
      array_view( T* s, T* f ):b(s), e(f) {}
      array_view( T* s, std::size_t l ):array_view(s, s+l) {}
    
      std::size_t size() const { return end()-begin(); }
      T& operator[]( std::size_t n ) const { return *(begin()+n); }
      array_view slice( std::size_t start, std::size_t length ) const {
        start = (std::min)(start, size());
        length = (std::min)(size()-start, length);
        return {b+start, length};
      }
    };
    

    Now our matrix view:

    template<class T>
    struct matrix_view {
      std::size_t height, width;
      std::size_t offset, stride;
      array_view<T> buffer;
    
      // TODO: Ctors
      // one from a whole matrix, with offset 0 and stride equal to its width;
      // another that lets you create a sub-matrix (same stride, shifted offset).
      // Row n of the view, assuming row-major storage:
      array_view<T> operator[]( std::size_t n ) const {
        return buffer.slice( offset+stride*n, width ); // for column-major storage this is column n, with length height
      }
    };
    

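    Now write your code to work on matrix_views instead of matrices. Splitting a matrix into its four quadrants then only means creating four views with different offsets and the same stride, which takes constant time.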
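Putting the pieces together, here is one way to fill in the TODO constructors and use the views in a recursive multiplication. This is a sketch under assumptions: row-major storage, square matrices whose size is a power of two, and hypothetical names `from_matrix`, `sub`, and `multiply_add`. The two constructors are written as a static factory and a member function so that matrix_view can stay an aggregate. For brevity the recursion is the plain eight-product divide-and-conquer scheme rather than Strassen's seven-product one, but the O(1) splitting is the same in both.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Non-owning view of a contiguous range [b, e).
template<class T>
struct array_view {
    T* b = nullptr;
    T* e = nullptr;
    array_view(T* s, T* f) : b(s), e(f) {}
    array_view(T* s, std::size_t l) : array_view(s, s + l) {}
    std::size_t size() const { return static_cast<std::size_t>(e - b); }
    T& operator[](std::size_t n) const { return b[n]; }
    // Sub-range starting at `start`, clamped to the end of this view.
    array_view slice(std::size_t start, std::size_t length) const {
        start = (std::min)(start, size());
        length = (std::min)(size() - start, length);
        return {b + start, length};
    }
};

// Row-major view: row i of the view starts at buffer[offset + stride * i].
template<class T>
struct matrix_view {
    std::size_t height, width;
    std::size_t offset, stride;
    array_view<T> buffer;

    // View of a whole h x w matrix stored contiguously in `data`.
    static matrix_view from_matrix(T* data, std::size_t h, std::size_t w) {
        return matrix_view{h, w, 0, w, array_view<T>(data, h * w)};
    }
    // h x w sub-matrix with top-left corner (row, col): O(1), nothing copied.
    matrix_view sub(std::size_t row, std::size_t col, std::size_t h, std::size_t w) const {
        return matrix_view{h, w, offset + stride * row + col, stride, buffer};
    }
    array_view<T> operator[](std::size_t n) const {
        return buffer.slice(offset + stride * n, width);
    }
};

// C += A * B for n x n views, n a power of two (assumption).
// Classic 8-product divide and conquer; Strassen would instead form
// 7 products from sums of these same quadrant views.
template<class T>
void multiply_add(const matrix_view<T>& A, const matrix_view<T>& B,
                  const matrix_view<T>& C) {
    std::size_t n = A.height;
    if (n == 1) {
        C[0][0] += A[0][0] * B[0][0];
        return;
    }
    std::size_t h = n / 2;
    // Quadrant (r, c) of m, as an O(1) view into the original storage.
    auto q = [h](const matrix_view<T>& m, std::size_t r, std::size_t c) {
        return m.sub(r * h, c * h, h, h);
    };
    for (std::size_t i = 0; i < 2; ++i)
        for (std::size_t j = 0; j < 2; ++j)
            for (std::size_t k = 0; k < 2; ++k)
                multiply_add(q(A, i, k), q(B, k, j), q(C, i, j));
}
```

Because `sub` only shifts `offset` and keeps the parent's `stride`, sub-views of sub-views still index the original buffer correctly. In practice you would stop the recursion at a small cutoff and fall back to a plain triple loop instead of recursing down to 1 × 1.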