
Broadcasting for matrix multiplication


Consider the following:

import numpy as np
np.random.seed(2)

result = SC @ x

Here SC is nn x nn and x is nn x ns.
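For a single SC, column i of the product is just SC applied to column i of x. The 3D case below generalizes exactly this pairing, with a different matrix per column. A small sketch (arbitrary random data from `default_rng`, not the question's values):

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, illustration only
nn, ns = 2, 4
SC = rng.random((nn, nn))
x = rng.random((nn, ns))

result = SC @ x  # shape (nn, ns)

# Column i of the product equals SC times column i of x.
for i in range(ns):
    assert np.allclose(result[:, i], SC @ x[:, i])
```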

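Now suppose SCs is a 3D array of shape ns x nn x nn, i.e. one nn x nn matrix per column of x, and column i of the result should be SCs[i] @ x[:, i]. With a loop:

ns = 4
nn = 2

SCs = np.random.rand(ns, nn, nn)
x = np.random.rand(nn, ns)

def matmul3d(a, b):
    # a: stack of ns matrices, each nn x nn; b: nn x ns
    ns, nn, _ = a.shape
    assert b.shape == (nn, ns)

    results = np.zeros((nn, ns))
    for i in range(ns):
        # multiply the i-th matrix by the i-th column
        results[:, i] = a[i] @ b[:, i]
    return results

matmul3d(SCs, x)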
array([[0.385428  , 0.22932766, 0.36791082, 0.06029485],
       [0.68934311, 0.14157493, 0.75236553, 0.09049892]])

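Simply using matrix multiplication multiplies every matrix by every column, giving an array of shape (ns, nn, ns); the desired result is its diagonal, results[i, :, i]: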

results = a @ b

array([[[0.385428  , 0.21717737, 0.38019609, 0.0372277 ],
        [0.68934311, 0.30008412, 0.65169432, 0.0858002 ]],

       [[0.52588409, 0.22932766, 0.4972909 , 0.06536792],
        [0.48764911, 0.14157493, 0.43837138, 0.07607813]],

       [[0.39071113, 0.1655206 , 0.36791082, 0.04962322],
        [0.79777992, 0.34153306, 0.75236553, 0.10054907]],

       [[0.37441129, 0.10004409, 0.33380446, 0.06029485],
        [0.5542946 , 0.14242876, 0.4923592 , 0.09049892]]])
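Extracting that diagonal can be done with `np.diagonal` over axes 0 and 2 (a sketch reusing the question's seed and shapes; `a` and `b` follow the names in `matmul3d`). It removes the loop, but it still computes all ns x ns matrix-column products and then discards the off-diagonal ones, so it wastes a factor of about ns in work and memory:

```python
import numpy as np

np.random.seed(2)
ns, nn = 4, 2
a = np.random.rand(ns, nn, nn)  # stack of matrices (SCs)
b = np.random.rand(nn, ns)      # x

full = a @ b  # shape (ns, nn, ns): every matrix times every column

# np.diagonal drops axes 0 and 2 and appends the diagonal as the last axis,
# so diag[j, i] == full[i, j, i] and diag has shape (nn, ns).
diag = np.diagonal(full, axis1=0, axis2=2)

# Reference: the explicit loop from the question.
loop = np.stack([a[i] @ b[:, i] for i in range(ns)], axis=1)
assert np.allclose(diag, loop)
```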

Is there a way to use broadcasting here to get rid of the loop?


Solution

  • You can do it with np.einsum; you just have to provide the subscripts string:

    np.einsum('ijk,ki->ji', a, b)
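The subscripts read as out[j, i] = sum over k of a[i, j, k] * b[k, i]: `i` appears in both inputs and in the output, so it pairs matrix i with column i instead of being summed, while `k` is missing from the output and is therefore summed. A sketch checking this against the loop (same seed and shapes as the question), alongside two equivalent loop-free formulations using only `@` and broadcasting; like einsum, none of them computes the unneeded off-diagonal products:

```python
import numpy as np

np.random.seed(2)
ns, nn = 4, 2
a = np.random.rand(ns, nn, nn)  # stack of matrices (SCs)
b = np.random.rand(nn, ns)      # x

# out[j, i] = sum_k a[i, j, k] * b[k, i]
out_einsum = np.einsum('ijk,ki->ji', a, b)

# Batched matmul: view each column of b as an (nn, 1) matrix stacked on the
# same leading axis as a, shape (ns, nn, 1); result (ns, nn, 1) -> (nn, ns).
out_matmul = (a @ b.T[:, :, None])[:, :, 0].T

# Plain broadcasting: b.T[:, None, :] has shape (ns, 1, nn), the product has
# shape (ns, nn, nn), and summing the last axis removes k.
out_bcast = (a * b.T[:, None, :]).sum(axis=-1).T

# Reference: the loop from the question.
loop = np.stack([a[i] @ b[:, i] for i in range(ns)], axis=1)
for out in (out_einsum, out_matmul, out_bcast):
    assert out.shape == (nn, ns)
    assert np.allclose(out, loop)
```

The batched `@` form relies on matmul treating leading axes as a stack of matrices, which is the broadcasting the question asks about; einsum is the most direct to read once the subscripts are familiar.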