python numpy matrix-multiplication

Custom matrix multiplication with NumPy


I want a custom "dot product" for matrix multiplication in NumPy. For a row [1,2,3] of matrix A and a column [4,5,6] of matrix B, I want to use the "product" min(1+4, 2+5, 3+6) in place of the usual sum of products when computing the matrix product AB.


Solution

  • In [498]: A = np.arange(12).reshape(4,3)                                             
    In [499]: B = np.arange(4,10).reshape(3,2)                                           
    In [500]: A                                                                          
    Out[500]: 
    array([[ 0,  1,  2],
           [ 3,  4,  5],
           [ 6,  7,  8],
           [ 9, 10, 11]])
    In [501]: B                                                                          
    Out[501]: 
    array([[4, 5],
           [6, 7],
           [8, 9]])
    

    Reference iterative solution:

    In [504]: res = np.zeros((A.shape[0],B.shape[1]), A.dtype) 
         ...: for i,row in enumerate(A): 
         ...:     for j,col in enumerate(B.T): 
         ...:         res[i,j] = np.min(row+col) 
         ...:                                                                            
    In [505]: res                                                                        
    Out[505]: 
    array([[ 4,  5],
           [ 7,  8],
           [10, 11],
           [13, 14]])
    

    Faster version using broadcasting:

    In [506]: np.min(A[:,:,None]+B[None,:,:], axis=1)                                    
    Out[506]: 
    array([[ 4,  5],
           [ 7,  8],
           [10, 11],
           [13, 14]])
    

    ===

    For comparison, the ordinary matrix product has the same structure, with * in place of + and sum in place of min:

    In [507]: np.dot(A,B)                                                                
    Out[507]: 
    array([[ 22,  25],
           [ 76,  88],
           [130, 151],
           [184, 214]])
    In [508]: np.sum(A[:,:,None]*B[None,:,:], axis=1)                                    
    Out[508]: 
    array([[ 22,  25],
           [ 76,  88],
           [130, 151],
           [184, 214]])
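
The two broadcasting expressions above differ only in the elementwise operation and the reduction, so they can be folded into one helper. The following is a minimal sketch, not part of the original answer: the names `generalized_matmul` and `min_plus_lowmem` are hypothetical, and the second function is an assumed alternative for cases where the (n, k, m) intermediate array from broadcasting would be too large.

```python
import numpy as np

def generalized_matmul(A, B, combine=np.add, reducer=np.min):
    """Hypothetical helper: combine A[i, k] with B[k, j], then reduce over k."""
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2 or A.shape[1] != B.shape[0]:
        raise ValueError("expected 2-D arrays with A.shape[1] == B.shape[0]")
    # Broadcasting builds an (n, k, m) intermediate, so memory grows with n*k*m.
    return reducer(combine(A[:, :, None], B[None, :, :]), axis=1)

def min_plus_lowmem(A, B):
    """Hypothetical variant: same min-plus result, holding only (n, m) arrays."""
    A = np.asarray(A)
    B = np.asarray(B)
    # Start from the k = 0 term, then fold in the remaining terms one at a time.
    out = A[:, [0]] + B[[0], :]
    for k in range(1, A.shape[1]):
        np.minimum(out, A[:, [k]] + B[[k], :], out=out)
    return out

A = np.arange(12).reshape(4, 3)
B = np.arange(4, 10).reshape(3, 2)

min_plus = generalized_matmul(A, B)                       # min over k of A[i,k] + B[k,j]
ordinary = generalized_matmul(A, B, np.multiply, np.sum)  # sum over k of A[i,k] * B[k,j]
print(min_plus)
print(ordinary)
```

With the A and B from the session above, `min_plus` should match Out[506] and `ordinary` should match `np.dot(A, B)`. The loop in `min_plus_lowmem` trades one Python-level iteration per shared index for a much smaller memory footprint, which may matter only for large inputs.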