
How to vectorize the following Python code

I'm trying to compute a matrix M in which each element is calculated as follows:

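import torch

batch_size, dim = 4, 3  # example sizes
X = torch.ones(batch_size, dim)
X_ = torch.ones(batch_size, dim)
Y = torch.ones(batch_size, dim)
M = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        # squared distance between X[i] and X_[i] scaled element-wise by Y[j]
        M[i, j] = ((X[i] - X_[i] * Y[j])**2).sum()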

Computing M element by element is very slow. Is there a way to use matrix multiplication to replace the for loops?



  • Since you sum() over dim, you can "lift" your 2D problem to 3D and sum there:

    M = ((X[:, None, :] - X_[:, None, :] * Y[None, ...])**2).sum(dim=2)

    How it works:

    X[:, None, :] and X_[:, None, :] are 3D tensors of size (batch_size, 1, dim), and Y[None, ...] is of size (1, batch_size, dim).

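    When computing X_[:, None, :] * Y[None, ...], PyTorch broadcasts each dimension of size 1 to match the other operand, giving a result of size (batch_size, batch_size, dim) whose [i, j] slice is X_[i] * Y[j]. Subtracting this from X[:, None, :] broadcasts again, so the [i, j] slice of the difference is X[i] - X_[i] * Y[j], exactly the term inside the loop.
    Finally, you sum() only over the last dimension (dim=2) to get an output M of size (batch_size, batch_size).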

    The whole trick relies on broadcasting.
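A quick way to convince yourself that the one-liner matches the double loop is to compare the two on random inputs. torch.ones is a poor test here, because with all-ones tensors every entry of M is 0 and an indexing mistake would go unnoticed. The sketch below assumes arbitrary example sizes and uses hypothetical helper names (`loop_version`, `broadcast_version`). It also prints the intermediate shapes described in the answer.

```python
import torch

def loop_version(X, X_, Y):
    # hypothetical reference helper: the original double loop from the question
    batch_size = X.shape[0]
    M = torch.zeros(batch_size, batch_size, dtype=X.dtype)
    for i in range(batch_size):
        for j in range(batch_size):
            M[i, j] = ((X[i] - X_[i] * Y[j]) ** 2).sum()
    return M

def broadcast_version(X, X_, Y):
    # hypothetical wrapper around the broadcasting one-liner from the answer
    return ((X[:, None, :] - X_[:, None, :] * Y[None, ...]) ** 2).sum(dim=2)

torch.manual_seed(0)
batch_size, dim = 5, 3  # example sizes, chosen arbitrarily
X = torch.randn(batch_size, dim, dtype=torch.float64)
X_ = torch.randn(batch_size, dim, dtype=torch.float64)
Y = torch.randn(batch_size, dim, dtype=torch.float64)

# shapes of the lifted operands and of the broadcast product
print(X[:, None, :].shape)                    # torch.Size([5, 1, 3])
print(Y[None, ...].shape)                     # torch.Size([1, 5, 3])
print((X_[:, None, :] * Y[None, ...]).shape)  # torch.Size([5, 5, 3])

M_loop = loop_version(X, X_, Y)
M_vec = broadcast_version(X, X_, Y)
print(M_vec.shape)                     # torch.Size([5, 5])
print(torch.allclose(M_loop, M_vec))   # True
```

Random float64 inputs are used so that the comparison is both meaningful and numerically tight.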
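Because the question asks specifically about matrix multiplication, one alternative may be worth knowing. Note that this technique is swapped in here and is not the answer's method. The broadcasting version materializes a tensor with batch_size * batch_size * dim elements, which can be prohibitive for large batches. Expanding the square avoids that. With x = X[i], x' = X_[i] and y = Y[j]:

sum_k (x_k - x'_k * y_k)^2 = sum_k x_k^2 - 2 * sum_k (x_k * x'_k) * y_k + sum_k x'_k^2 * y_k^2

The first term depends only on i. The last two terms are dot products between a row built from X, X_ and a row of Y (or Y**2), so they become matrix products with Y.T and (Y**2).T. The sketch below assumes float64 inputs and a hypothetical function name `matmul_version`.

```python
import torch

def matmul_version(X, X_, Y):
    # hypothetical helper: M[i, j] = sum_k (X[i,k] - X_[i,k] * Y[j,k])**2,
    # expanded into three terms so only (batch_size, batch_size)-sized results are built
    sq_x = (X ** 2).sum(dim=1, keepdim=True)  # (batch_size, 1): sum_k X[i,k]^2
    cross = (X * X_) @ Y.T                     # (batch_size, batch_size): sum_k X[i,k] X_[i,k] Y[j,k]
    quad = (X_ ** 2) @ (Y ** 2).T              # (batch_size, batch_size): sum_k X_[i,k]^2 Y[j,k]^2
    return sq_x - 2 * cross + quad             # sq_x broadcasts across the columns

torch.manual_seed(0)
batch_size, dim = 5, 3  # example sizes, chosen arbitrarily
X = torch.randn(batch_size, dim, dtype=torch.float64)
X_ = torch.randn(batch_size, dim, dtype=torch.float64)
Y = torch.randn(batch_size, dim, dtype=torch.float64)

# compare against the broadcasting one-liner
M_bcast = ((X[:, None, :] - X_[:, None, :] * Y[None, ...]) ** 2).sum(dim=2)
M_mm = matmul_version(X, X_, Y)
print(M_mm.shape)                      # torch.Size([5, 5])
print(torch.allclose(M_bcast, M_mm))   # True
```

The trade-off is numerical. When the three terms are large and nearly cancel, the expanded form can lose precision and even produce tiny negative values, especially in float32. The broadcasting form does not have this problem, so it is the safer choice when memory allows.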