
Does the leading dimension in cuBLAS allow for accessing any submatrix?


I'm trying to understand the idea of the leading dimension in cuBLAS. It's mentioned that lda must always be greater than or equal to the number of rows in the matrix.

If I have a 100x100 matrix A and I want to access A(90:99, 0:99), what would the arguments of cublasSetMatrix be? lda specifies the distance, in elements, between the starts of consecutive columns (100 in this case), but where would I specify the 90? The only way I can see is adjusting *A.

The function definition is:

cublasStatus_t cublasSetMatrix(int rows, int cols, int elemSize, const void *A, int lda, void *B, int ldb)

And I'm also guessing that I wouldn't be able to transfer the bottom right 3x3 portion of a 5x5 matrix given the length limits.


Solution

  • You have to "adjust *A", as you called it. The pointer that is given to this function must be the starting entry of the respective sub-matrix.

    You did not say whether your matrix A is actually the input or the output matrix, but this does not change much conceptually.

    Assuming you have the following code:

    // The matrix in host memory
    int rowsA = 100;
    int colsA = 100;
    float *A = new float[rowsA*colsA];
    
    // Fill A with values
    ...
    
    // The sub-matrix that should be copied to the device.
    // The minimum index is INCLUSIVE
    // The maximum index is EXCLUSIVE
    int minRowA = 0;
    int maxRowA = 100;
    int minColA = 90;
    int maxColA = 100;
    int rowsB = maxRowA-minRowA;
    int colsB = maxColA-minColA;
    
    // Allocate the device matrix
    float *dB = nullptr;
    cudaMalloc(&dB, rowsB * colsB * sizeof(float));
    

    Then, for the cublasSetMatrix call, you have to compute the starting element of the source matrix:

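    // Column-major storage: element (row, col) of A lives at A[row + col * rowsA]
    float *sourceA = A + (minRowA + minColA * rowsA);
    // lda = rowsA (stride of the full host matrix), ldb = rowsB (dense device matrix)
    cublasSetMatrix(rowsB, colsB, sizeof(float), sourceA, rowsA, dB, rowsB);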
    

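    And this is where the 90 that you asked for comes into play: it is the minColA in the computation of the source pointer. The example above selects columns 90 to 99 of all rows. To select rows 90 to 99 of all columns instead, as in A(90:99, 0:99), set minRowA = 90, maxRowA = 100, minColA = 0 and maxColA = 100; the pointer computation and the cublasSetMatrix call stay the same.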
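
The pointer adjustment can be checked on the host without a GPU. The following is a minimal sketch, not the cuBLAS implementation: copyTile, extractSubMatrix and makeTagged are hypothetical helper names introduced here, and copyTile only mimics the column-major addressing that cublasSetMatrix is documented to use (element (i, j) at offset i + j * ld). It also suggests that the bottom-right 3x3 block of a 5x5 matrix is transferable, assuming the constraint on lda is that it be at least the number of rows being copied (3), which lda = 5 satisfies.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-only stand-in for the addressing used by cublasSetMatrix:
// copies a rows x cols tile between two column-major buffers.
// lda and ldb are the leading dimensions (allocated row counts) of the
// source and destination buffers.
void copyTile(int rows, int cols,
              const float *A, int lda,
              float *B, int ldb)
{
    // The leading dimension must cover at least the rows being copied.
    assert(lda >= rows && ldb >= rows);
    for (int j = 0; j < cols; ++j) {
        for (int i = 0; i < rows; ++i) {
            B[i + static_cast<std::size_t>(j) * ldb] =
                A[i + static_cast<std::size_t>(j) * lda];
        }
    }
}

// Extracts the sub-matrix [minRow, maxRow) x [minCol, maxCol) of a
// column-major parent matrix with parentRows rows into a dense buffer.
std::vector<float> extractSubMatrix(const std::vector<float> &parent,
                                    int parentRows,
                                    int minRow, int maxRow,
                                    int minCol, int maxCol)
{
    const int rows = maxRow - minRow;
    const int cols = maxCol - minCol;
    std::vector<float> sub(static_cast<std::size_t>(rows) * cols);

    // Same pointer adjustment as in the answer:
    // skip minCol full columns, then minRow rows.
    const float *source =
        parent.data() + (minRow + static_cast<std::size_t>(minCol) * parentRows);

    // lda is the row count of the parent, ldb the row count of the tile.
    copyTile(rows, cols, source, parentRows, sub.data(), rows);
    return sub;
}

// Builds a column-major rows x cols matrix whose element (i, j)
// holds the value i * 1000 + j, so that each entry identifies itself.
std::vector<float> makeTagged(int rows, int cols)
{
    std::vector<float> m(static_cast<std::size_t>(rows) * cols);
    for (int j = 0; j < cols; ++j) {
        for (int i = 0; i < rows; ++i) {
            m[i + static_cast<std::size_t>(j) * rows] =
                static_cast<float>(i * 1000 + j);
        }
    }
    return m;
}
```

With these helpers, extractSubMatrix(A, 100, 90, 100, 0, 100) yields the 10x100 block A(90:99, 0:99), and extractSubMatrix(M, 5, 2, 5, 2, 5) yields the bottom-right 3x3 block of a 5x5 matrix. For the device transfer, the copyTile call would be replaced by cublasSetMatrix(rows, cols, sizeof(float), source, parentRows, dB, rows), with dB allocated for rows * cols floats.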