
How does the tiled multiplication of two matrices work?


My teacher wrote:

Implement a CUDA kernel for matrix products as outer product vectors. In this version, each block of K threads counts a square piece of the result matrix of size KxK by implementing the matrix outer product formula. The kernel uses shared memory to store the corresponding column vectors from matrix A and matrix B and to store the corresponding fragment of the result array.

As far as I understand from this text, he wants me to implement matrix multiplication using vector outer products while also incorporating tiling. That is why I came up with the following.

Suppose I want to multiply the following matrices using k x k = 2x2 tiles:

A = B = [[ 1  2  3  4]
         [ 5  6  7  8]
         [ 9 10 11 12]
         [13 14 15 16]]

The multiplication result would be:

[image not available]

Does tile multiplication work this way?

Or, am I missing something?

===========================================================================================
k   r   i   j           a[i][r]     b[r][j]       a[i][r]*b[r][j]  c[i][j]    
===========================================================================================           
2                       a[0][0]=1   b[0][0]=1     1*1=1            c[0][0]=1     
                        a[0][0]=1   b[0][1]=2     1*2=2            c[0][1]=2
                        a[1][0]=5   b[0][0]=1     5*1=5            c[1][0]=5
                        a[1][0]=5   b[0][1]=2     5*2=10           c[1][1]=10
-------------------------------------------------------------------------------------------
                        a[0][1]=2   b[1][0]=5     2*5=10           c[0][0]=(1+10)=11 
                        a[0][1]=2   b[1][1]=6     2*6=12           c[0][1]=(2+12)=14
                        a[1][1]=6   b[1][0]=5     6*5=30           c[1][0]=(5+30)=35
                        a[1][1]=6   b[1][1]=6     6*6=36           c[1][1]=(10+36)=46
-------------------------------------------------------------------------------------------
2                       a[0][2]=3   b[2][0]=9     3*9=27           c[0][0]=(11+27)=38      
                        a[0][2]=3   b[2][1]=10    3*10=30          c[0][1]=(14+30)=44
                        a[1][2]=7   b[2][0]=9     7*9=63           c[1][0]=(35+63)=98
                        a[1][2]=7   b[2][1]=10    7*10=70          c[1][1]=(46+70)=116
-------------------------------------------------------------------------------------------
                        a[0][3]=4   b[3][0]=13    4*13=52          c[0][0]=(38+52)=90 
                        a[0][3]=4   b[3][1]=14    4*14=56          c[0][1]=(44+56)=100
                        a[1][3]=8   b[3][0]=13    8*13=104         c[1][0]=(98+104)=202
                        a[1][3]=8   b[3][1]=14    8*14=112         c[1][1]=(116+112)=228
===========================================================================================
2                       a[0][0]=1   b[0][2]=3     1*3=3            c[0][2]=3     
                        a[0][0]=1   b[0][3]=4     1*4=4            c[0][3]=4
                        a[1][0]=5   b[0][2]=3     5*3=15           c[1][2]=15
                        a[1][0]=5   b[0][3]=4     5*4=20           c[1][3]=20
-------------------------------------------------------------------------------------------
                        a[0][1]=2   b[1][2]=7     2*7=14           c[0][2]=(3+14)=17 
                        a[0][1]=2   b[1][3]=8     2*8=16           c[0][3]=(4+16)=20
                        a[1][1]=6   b[1][2]=7     6*7=42           c[1][2]=(15+42)=57
                        a[1][1]=6   b[1][3]=8     6*8=48           c[1][3]=(20+48)=68
-------------------------------------------------------------------------------------------
2                       a[0][2]=3   b[2][2]=11    3*11=33          c[0][2]=(17+33)=50      
                        a[0][2]=3   b[2][3]=12    3*12=36          c[0][3]=(20+36)=56
                        a[1][2]=7   b[2][2]=11    7*11=77          c[1][2]=(57+77)=134
                        a[1][2]=7   b[2][3]=12    7*12=84          c[1][3]=(68+84)=152
-------------------------------------------------------------------------------------------
                        a[0][3]=4   b[3][2]=15    4*15=60          c[0][2]=(50+60)=110 
                        a[0][3]=4   b[3][3]=16    4*16=64          c[0][3]=(56+64)=120
                        a[1][3]=8   b[3][2]=15    8*15=120         c[1][2]=(134+120)=254
                        a[1][3]=8   b[3][3]=16    8*16=128         c[1][3]=(152+128)=280
===========================================================================================
2                       a[2][0]=9   b[0][0]=1     9*1=9             c[2][0]=(0+9)=9     
                        a[2][0]=9   b[0][1]=2     9*2=18            c[2][1]=(0+18)=18
                        a[3][0]=13  b[0][0]=1     13*1=13           c[3][0]=(0+13)=13
                        a[3][0]=13  b[0][1]=2     13*2=26           c[3][1]=(0+26)=26
-------------------------------------------------------------------------------------------
                        a[2][1]=10  b[1][0]=5     10*5=50           c[2][0]=(9+50)=59     
                        a[2][1]=10  b[1][1]=6     10*6=60           c[2][1]=(18+60)=78
                        a[3][1]=14  b[1][0]=5     14*5=70           c[3][0]=(13+70)=83
                        a[3][1]=14  b[1][1]=6     14*6=84           c[3][1]=(26+84)=110
-------------------------------------------------------------------------------------------
2                       a[2][2]=11  b[2][0]=9     11*9=99           c[2][0]=(59+99)=158     
                        a[2][2]=11  b[2][1]=10    11*10=110         c[2][1]=(78+110)=188
                        a[3][2]=15  b[2][0]=9     15*9=135          c[3][0]=(83+135)=218
                        a[3][2]=15  b[2][1]=10    15*10=150         c[3][1]=(110+150)=260
-------------------------------------------------------------------------------------------
2                       a[2][3]=12  b[3][0]=13    12*13=156         c[2][0]=(158+156)=314     
                        a[2][3]=12  b[3][1]=14    12*14=168         c[2][1]=(188+168)=356 
                        a[3][3]=16  b[3][0]=13    16*13=208         c[3][0]=(218+208)=426
                        a[3][3]=16  b[3][1]=14    16*14=224         c[3][1]=(260+224)=484 
===========================================================================================
2                       a[2][0]=9   b[0][2]=3     9*3=27            c[2][2]=(0+27)=27     
                        a[2][0]=9   b[0][3]=4     9*4=36            c[2][3]=(0+36)=36
                        a[3][0]=13  b[0][2]=3     13*3=39           c[3][2]=(0+39)=39
                        a[3][0]=13  b[0][3]=4     13*4=52           c[3][3]=(0+52)=52
-------------------------------------------------------------------------------------------
                        a[2][1]=10  b[1][2]=7     10*7=70           c[2][2]=(27+70)=97     
                        a[2][1]=10  b[1][3]=8     10*8=80           c[2][3]=(36+80)=116
                        a[3][1]=14  b[1][2]=7     14*7=98           c[3][2]=(39+98)=137
                        a[3][1]=14  b[1][3]=8     14*8=112          c[3][3]=(52+112)=164
-------------------------------------------------------------------------------------------
2                       a[2][2]=11  b[2][2]=11    11*11=121         c[2][2]=(97+121)=218     
                        a[2][2]=11  b[2][3]=12    11*12=132         c[2][3]=(116+132)=248
                        a[3][2]=15  b[2][2]=11    15*11=165         c[3][2]=(137+165)=302
                        a[3][2]=15  b[2][3]=12    15*12=180         c[3][3]=(164+180)=344
-------------------------------------------------------------------------------------------
2                       a[2][3]=12  b[3][2]=15    12*15=180         c[2][2]=(218+180)=398     
                        a[2][3]=12  b[3][3]=16    12*16=192         c[2][3]=(248+192)=440 
                        a[3][3]=16  b[3][2]=15    16*15=240         c[3][2]=(302+240)=542
                        a[3][3]=16  b[3][3]=16    16*16=256         c[3][3]=(344+256)=600   
===========================================================================================

Solution

  • In general, a “tiled” matrix multiplication means restructuring the matrix product as a product of block matrices, so that A dot B could (as one example) be expressed as:

    
    +----+----+     +----+----+     +-------------+-------------+
    | A1 | A2 |     | B1 | B2 |     | A1.B1+A2.B3 | A1.B2+A2.B4 |
    +----+----+ dot +----+----+  =  +-------------+-------------+
    | A3 | A4 |     | B3 | B4 |     | A3.B1+A4.B3 | A3.B2+A4.B4 |
    +----+----+     +----+----+     +-------------+-------------+
    
    

    For the 4x4 case you are asking about, this 2x2 block structure implies that each sub-matrix of A and B is 2x2.
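
    As a quick sanity check of the block formula, here is a minimal pure-Python sketch using the 4x4 matrix from the question. The helper names (matmul, block, add) are made up for illustration, and the sub-matrix products are done with ordinary inner products at this stage.

```python
# Sketch: verify the 2x2 block formula against a plain matrix product.
# Helper names are hypothetical; A = B is the 4x4 matrix from the question.

def matmul(X, Y):
    """Plain inner-product matrix multiplication."""
    return [[sum(X[i][r] * Y[r][j] for r in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def block(M, bi, bj, k):
    """k x k sub-matrix of M at block row bi, block column bj."""
    return [row[bj * k:(bj + 1) * k] for row in M[bi * k:(bi + 1) * k]]

def add(X, Y):
    """Element-wise sum of two equally sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
B = A
k = 2

A1, A2 = block(A, 0, 0, k), block(A, 0, 1, k)
A3, A4 = block(A, 1, 0, k), block(A, 1, 1, k)
B1, B2 = block(B, 0, 0, k), block(B, 0, 1, k)
B3, B4 = block(B, 1, 0, k), block(B, 1, 1, k)

# The four result blocks, exactly as in the diagram above.
C1 = add(matmul(A1, B1), matmul(A2, B3))
C2 = add(matmul(A1, B2), matmul(A2, B4))
C3 = add(matmul(A3, B1), matmul(A4, B3))
C4 = add(matmul(A3, B2), matmul(A4, B4))

# Stitch the blocks back into a 4x4 matrix and compare with the full product.
C_blocked = ([r1 + r2 for r1, r2 in zip(C1, C2)] +
             [r3 + r4 for r3, r4 in zip(C3, C4)])
assert C_blocked == matmul(A, B)
print(C1)  # [[90, 100], [202, 228]]
```

    The top-left block C1 matches the final c[0][0], c[0][1], c[1][0], c[1][1] values in your table.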

    If you choose to perform each sub-matrix product by outer product expansion, rather than the set of inner products you already know about, you do that as follows, using A1.B1 as an example:

    A1.B1 = sum(outer(A1[:,1],B1[1,:]), outer(A1[:,2],B1[2,:]), ...,
                outer(A1[:,N],B1[N,:]))
    

    which is

    outer([a11,a21],[b11,b12]) + outer([a12,a22],[b21,b22])

    or

    | a11*b11 a11*b12 | + | a12*b21 a12*b22 | = | a11*b11+a12*b21 a11*b12+a12*b22 |
    | a21*b11 a21*b12 |   | a22*b21 a22*b22 |   | a21*b11+a22*b21 a21*b12+a22*b22 |
    

    for the 2x2 case.

    It should be trivial to confirm that the terms on the RHS of the result are identical to those you would obtain by calculating the set of inner products of the rows and columns of the two sub-matrices.
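
    The same expansion can be sketched in a few lines of plain Python. The function names outer and outer_product_matmul are invented for this example, and A1 and B1 are the top-left 2x2 blocks of the matrix in the question.

```python
# Sketch: one sub-matrix product computed as a sum of outer products.
# Function names are hypothetical, not from any library.

def outer(u, v):
    """Outer product of vectors u and v: result[i][j] = u[i] * v[j]."""
    return [[ui * vj for vj in v] for ui in u]

def outer_product_matmul(X, Y):
    """X.Y as the sum over r of outer(column r of X, row r of Y)."""
    n_rows, n_inner, n_cols = len(X), len(Y), len(Y[0])
    C = [[0] * n_cols for _ in range(n_rows)]
    for r in range(n_inner):
        col = [X[i][r] for i in range(n_rows)]  # column r of X
        row = Y[r]                              # row r of Y
        rank1 = outer(col, row)                 # one rank-1 update
        for i in range(n_rows):
            for j in range(n_cols):
                C[i][j] += rank1[i][j]
    return C

A1 = [[1, 2], [5, 6]]
B1 = [[1, 2], [5, 6]]

print(outer([1, 5], [1, 2]))         # [[1, 2], [5, 10]]
print(outer([2, 6], [5, 6]))         # [[10, 12], [30, 36]]
print(outer_product_matmul(A1, B1))  # [[11, 14], [35, 46]]
```

    The two rank-1 terms and their sum correspond to the first two groups of rows in your table (c[0][0]=11, c[0][1]=14, c[1][0]=35, c[1][1]=46).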

    You repeat this process for the other seven sub-matrix products and accumulate the results to yield the complete matrix multiplication.
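
    Putting the pieces together, here is a rough CPU-side sketch in Python of how the whole computation could be organised. It is not a CUDA kernel, only a model of the loop structure: the function name tiled_outer_matmul is made up, the matrices are assumed to be square with a size divisible by k, and the local list tile merely stands in for the shared-memory result fragment mentioned in your assignment.

```python
# Sketch: tiled matrix multiplication via outer products (CPU model only).
# Assumes square n x n inputs with n divisible by k; names are hypothetical.

def tiled_outer_matmul(A, B, k):
    n = len(A)
    C = [[0] * n for _ in range(n)]
    for bi in range(0, n, k):        # each (bi, bj) pair plays the role
        for bj in range(0, n, k):    # of one thread block / one k x k tile of C
            tile = [[0] * k for _ in range(k)]  # stand-in for the shared result fragment
            for r in range(n):       # one outer product per inner index r
                a_col = [A[bi + i][r] for i in range(k)]  # k entries of column r of A
                b_row = [B[r][bj + j] for j in range(k)]  # k entries of row r of B
                for i in range(k):
                    for j in range(k):
                        tile[i][j] += a_col[i] * b_row[j]
            for i in range(k):       # write the finished tile back to C
                C[bi + i][bj:bj + k] = tile[i]
    return C

A = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
C = tiled_outer_matmul(A, A, 2)
for row in C:
    print(row)
# [90, 100, 110, 120]
# [202, 228, 254, 280]
# [314, 356, 398, 440]
# [426, 484, 542, 600]
```

    For the top-left tile, r = 0, 1 accumulates A1.B1 and r = 2, 3 accumulates A2.B3, which is the order your table follows. In an actual kernel the two inner loops over i and j would presumably be spread across the threads of the block, but that mapping is a design choice the assignment text leaves open.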