
how to vectorize the scatter-matmul operation

I have many matrices w1, w2, w3, ..., wn with shapes k*n1, k*n2, k*n3, ..., k*nn, and matrices x1, x2, x3, ..., xn with shapes n1*m, n2*m, n3*m, ..., nn*m.
I want to compute w1@x1, w2@x2, w3@x3, ..., wn@xn.

The results are n matrices of shape k*m, which can be concatenated vertically into one large matrix of shape (k*n)*m.

Multiplying them one by one is slow. How can I vectorize this operation?

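Note: The inputs can also be given as a single k*(n1+n2+n3+...+nn) matrix (the w_i concatenated horizontally) and a single (n1+n2+n3+...+nn)*m matrix (the x_i concatenated vertically), with a batch index indicating which submatrix each column/row belongs to.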
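To make the packed layout from the note concrete, here is a minimal NumPy sketch, assuming the w_i are concatenated along columns into `w_cat` (k x N) and the x_i along rows into `x_cat` (N x m), with N = n1+...+nn. The function name `scatter_matmul_loop` is hypothetical; it is the slow per-block loop the question wants to avoid, useful only as a reference to check a vectorized version against.

```python
import numpy as np

def scatter_matmul_loop(w_cat, x_cat, sizes):
    """Reference (non-vectorized) scatter-matmul on packed inputs.

    w_cat: (k, N) array, w_1 .. w_n concatenated along columns
    x_cat: (N, m) array, x_1 .. x_n concatenated along rows
    sizes: [n1, ..., nn], with sum(sizes) == N
    Returns the (k*n, m) vertical stack of w_i @ x_i.
    """
    bounds = np.cumsum(sizes)[:-1]             # split points between blocks
    w_blocks = np.split(w_cat, bounds, axis=1)  # w_i: (k, n_i)
    x_blocks = np.split(x_cat, bounds, axis=0)  # x_i: (n_i, m)
    return np.vstack([w @ x for w, x in zip(w_blocks, x_blocks)])

# The batch index assigns each of the N shared columns/rows to its block.
sizes = [2, 3, 1]
batch = np.repeat(np.arange(len(sizes)), sizes)  # [0, 0, 1, 1, 1, 2]
# Going the other way, the block sizes can be recovered from the batch index:
sizes_from_batch = np.bincount(batch)            # [2, 3, 1]
```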

This operation is related to the scatter operations implemented in pytorch_scatter, so I refer to it as "scatter_matmul".
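The scatter analogy can be made literal: block i of the result is the sum, over the shared indices j with batch[j] == i, of the outer products w_cat[:, j] x_cat[j, :]. The sketch below is one possible vectorization along those lines, written in NumPy with a hypothetical `scatter_matmul` name; it forms all N outer products at once and scatter-adds them into their blocks with `np.add.at`. It assumes the packed `w_cat`, `x_cat`, and integer `batch` layout described above, and it needs O(N*k*m) temporary memory.

```python
import numpy as np

def scatter_matmul(w_cat, x_cat, batch, num_blocks):
    """Vectorized scatter-matmul via outer products and a scatter-add.

    w_cat: (k, N), x_cat: (N, m), batch: (N,) ints in [0, num_blocks).
    Returns the (num_blocks*k, m) vertical stack of w_i @ x_i.
    """
    k = w_cat.shape[0]
    m = x_cat.shape[1]
    # outer[j] = outer product of column j of w_cat and row j of x_cat: (N, k, m)
    outer = w_cat.T[:, :, None] * x_cat[:, None, :]
    out = np.zeros((num_blocks, k, m), dtype=np.result_type(w_cat, x_cat))
    # Unbuffered scatter-add: out[batch[j]] += outer[j] for every j
    np.add.at(out, batch, outer)
    # Row-major reshape stacks the k x m blocks vertically
    return out.reshape(num_blocks * k, m)
```

In PyTorch the scatter-add step would correspond to `Tensor.index_add_` along dimension 0. Whether this beats a loop depends on N, k and m, since the (N, k, m) intermediate can be large.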


  • You can vectorize the operation by building one large block-diagonal matrix W of shape (k*n)x(n1+...+nn), whose diagonal blocks are the w_i matrices. Then stack all the x_i matrices vertically into a single matrix X of shape (n1+...+nn)xm. Multiplying the block-diagonal W by the stacked X:

    Y = W @ X

    yields Y of shape (k*n)xm, which is exactly the concatenated large matrix you are looking for.

    If the block-diagonal matrix W is too large to fit in memory, you may consider storing W as a sparse matrix and computing the product using
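A minimal dense sketch of the block-diagonal approach, written in NumPy for illustration; `block_diag` and `scatter_matmul_blockdiag` are hypothetical helper names. In practice, `torch.block_diag` (PyTorch) or `scipy.linalg.block_diag` build the same dense matrix, and `scipy.sparse.block_diag` builds a sparse one.

```python
import numpy as np

def block_diag(blocks):
    """Place 2-D arrays on the diagonal of an otherwise zero (dense) matrix."""
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    out = np.zeros((rows, cols), dtype=np.result_type(*blocks))
    r = c = 0
    for b in blocks:
        out[r:r + b.shape[0], c:c + b.shape[1]] = b  # copy block i into place
        r += b.shape[0]
        c += b.shape[1]
    return out

def scatter_matmul_blockdiag(ws, xs):
    """ws: list of (k, n_i) arrays, xs: list of (n_i, m) arrays."""
    W = block_diag(ws)   # (k*n, n1+...+nn)
    X = np.vstack(xs)    # (n1+...+nn, m)
    return W @ X         # (k*n, m): rows i*k .. (i+1)*k - 1 hold w_i @ x_i
```

One design trade-off: a dense W @ X performs about n times the multiply-adds of the per-block loop, because all but one block in each row band of W is zero. That overhead is why the sparse variant becomes attractive as n grows.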