
How to easily perform this random matrix multiplication with numpy?


I want to generate two random 3x4 matrices, A and B, whose entries are normally distributed. I then have a 2x2 matrix C = [[a, b], [c, d]], and I would like to use it to produce two new 3x4 matrices A' and B', where A' = aA + bB and B' = cA + dB.
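Written out, the two definitions are a single linear map applied to A and B stacked on top of each other (a 6x4 block column). Here I_3 is the 3x3 identity and ⊗ the Kronecker product; this notation is added for illustration and is not part of the original question:

```latex
\begin{pmatrix} A' \\ B' \end{pmatrix}
= \begin{pmatrix} a I_3 & b I_3 \\ c I_3 & d I_3 \end{pmatrix}
  \begin{pmatrix} A \\ B \end{pmatrix}
= (C \otimes I_3) \begin{pmatrix} A \\ B \end{pmatrix}
= \begin{pmatrix} aA + bB \\ cA + dB \end{pmatrix}
```

So C only ever mixes whole matrices along the "which matrix" axis, which is why a single contraction over that axis is enough.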

To produce A and B, I was thinking of using this line of code:

Z = np.random.normal(0.0, 1.0, [2, 3, 4])

However, given the matrix C, I don't know how to use simple NumPy vectorization to compute A' and B' or, equivalently, a 2x3x4 array containing A' and B'. Any ideas?


Solution

  • You can use np.einsum:

    np.einsum("ij, jkl -> ikl", C, Z)
    

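    Here "ij, jkl -> ikl" is the contraction pattern: i and j index the rows and columns of C, and j, k and l index the three axes of Z. Because j appears in both inputs but not in the output, it is summed over, so result[i] = C[i][0] * Z[0] + C[i][1] * Z[1]. In other words, result[0] is A' and result[1] is B'.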
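To make the contraction pattern concrete, here is a minimal sketch that spells out "ij,jkl->ikl" as explicit Python loops and checks it against np.einsum. It reuses the dummy data from the example below; the names loop_result and einsum_result are illustrative only. The loops are for understanding, not speed: einsum does the same summation in compiled code.

```python
import numpy as np

# Same dummy data as in the example below (legacy global seed).
np.random.seed(0)
Z = np.random.normal(0.0, 1.0, [2, 3, 4])   # Z[0] plays the role of A, Z[1] of B
C = np.array([[1, 2], [3, 4]], dtype=float)

# Explicit loops for "ij,jkl->ikl": for each output slice i, add up
# C[i, j] * Z[j] over the repeated index j (each Z[j] is a whole 3x4 matrix).
loop_result = np.zeros((C.shape[0],) + Z.shape[1:])
for i in range(C.shape[0]):
    for j in range(C.shape[1]):
        loop_result[i] += C[i, j] * Z[j]

einsum_result = np.einsum("ij,jkl->ikl", C, Z)
print(np.allclose(loop_result, einsum_result))  # True
```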


    Example

    Given dummy data like the following:

    import numpy as np

    np.random.seed(0)
    Z = np.random.normal(0.0, 1.0, [2, 3, 4])
    C = [[1, 2], [3, 4]]  # a plain list works; einsum converts it to an array
    

    running

    print("AB_prim(einsum): \n", np.einsum("ij, jkl -> ikl", C, Z))
    

    prints

    AB_prim(einsum): 
     [[[  3.2861278    0.64350724   1.86646445   2.90824185]
      [  4.85571614  -1.38759441   1.57622382  -1.85954869]
      [ -5.20919848   1.71783569   1.87291597  -0.03005653]]
    
     [[  8.33630794   1.68717169   4.71166688   8.05737691]
      [ 11.57899026  -3.75246669   4.10253606  -3.87045458]
      [-10.52161582   3.84626989   3.88987551   1.39416044]]]
    

    This matches computing A' and B' directly from their definitions:

    A, B = Z[0], Z[1]
    print("A_prim: \n", C[0][0] * A + C[0][1] * B)
    print("B_prim: \n", C[1][0] * A + C[1][1] * B)
    

    which prints

    A_prim: 
     [[ 3.2861278   0.64350724  1.86646445  2.90824185]
     [ 4.85571614 -1.38759441  1.57622382 -1.85954869]
     [-5.20919848  1.71783569  1.87291597 -0.03005653]]
    B_prim: 
     [[  8.33630794   1.68717169   4.71166688   8.05737691]
     [ 11.57899026  -3.75246669   4.10253606  -3.87045458]
     [-10.52161582   3.84626989   3.88987551   1.39416044]]
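Rather than comparing the printed numbers by eye, the agreement can be checked with np.allclose. The sketch below also shows a few other standard NumPy ways to do the same contraction (np.tensordot, an ordinary matrix product after reshaping, and broadcasting), all compared against einsum. As an assumption beyond the original answer, it draws Z with the newer Generator API (np.random.default_rng, NumPy 1.17+), so its values differ from the np.random.seed(0) output above.

```python
import numpy as np

rng = np.random.default_rng(0)       # Generator API; values differ from np.random.seed(0)
Z = rng.standard_normal((2, 3, 4))   # Z[0] is A, Z[1] is B
C = np.array([[1.0, 2.0], [3.0, 4.0]])

ref = np.einsum("ij,jkl->ikl", C, Z)

# 1) tensordot: contract the last axis of C with the first axis of Z.
via_tensordot = np.tensordot(C, Z, axes=1)

# 2) Flatten each 3x4 matrix into a row of 12, do a plain (2x2) @ (2x12)
#    matrix product, then restore the 3x4 shape of each slice.
via_matmul = (C @ Z.reshape(Z.shape[0], -1)).reshape(C.shape[0], *Z.shape[1:])

# 3) Broadcasting: C[:, :, None, None] has shape (2, 2, 1, 1), Z[None] has
#    shape (1, 2, 3, 4); multiply, then sum over the shared j axis (axis 1).
via_broadcast = (C[:, :, None, None] * Z[None]).sum(axis=1)

for name, arr in [("tensordot", via_tensordot),
                  ("matmul", via_matmul),
                  ("broadcast", via_broadcast)]:
    print(name, arr.shape, np.allclose(arr, ref))   # e.g. "tensordot (2, 3, 4) True"

# Unpacking along the first axis gives the two 3x4 result matrices.
A_prim, B_prim = ref
```

np.tensordot(C, Z, axes=1) is arguably the most direct alternative here, since it names exactly the one axis being contracted; einsum stays more readable once the pattern involves more axes.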