I'm trying to compute a matrix `M` where each element is calculated as follows:
```python
import torch

batch_size, dim = 128, 64  # example sizes

X = torch.ones(batch_size, dim)
X_ = torch.ones(batch_size, dim)
Y = torch.ones(batch_size, dim)
M = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        M[i, j] = ((X[i] - X_[i] * Y[j])**2).sum()
```
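For reference, the same definition written as a formula, with $X'$ standing for `X_` and $k$ running over the `dim` axis (this notation is introduced here for illustration):

$$
M_{ij} = \sum_{k=1}^{\text{dim}} \left( X_{ik} - X'_{ik}\, Y_{jk} \right)^2, \qquad i, j = 1, \dots, \text{batch\_size}
$$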
Computing `M` element by element like this is very slow. Is there a way to use matrix multiplication to replace the for loops?

Thanks.

If you want to `sum()` over `dim`, you can "lift" your 2D problem to 3D and sum there:

```python
M = ((X[:, None, :] - X_[:, None, :] * Y[None, ...])**2).sum(dim=2)
```
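A quick way to convince yourself the one-liner matches the double loop is to run both on small random inputs (random rather than `torch.ones`, so a mismatch would actually be visible). The sizes below are arbitrary assumptions for the check:

```python
import torch

torch.manual_seed(0)     # reproducible random inputs
batch_size, dim = 4, 3   # small illustrative sizes (assumed)

X = torch.randn(batch_size, dim)
X_ = torch.randn(batch_size, dim)
Y = torch.randn(batch_size, dim)

# Reference result: the original double loop
M_loop = torch.zeros(batch_size, batch_size)
for i in range(batch_size):
    for j in range(batch_size):
        M_loop[i, j] = ((X[i] - X_[i] * Y[j]) ** 2).sum()

# Broadcast version: one vectorized expression, no Python loops
M_bcast = ((X[:, None, :] - X_[:, None, :] * Y[None, ...]) ** 2).sum(dim=2)

print(M_bcast.shape)                     # torch.Size([4, 4])
print(torch.allclose(M_loop, M_bcast))   # True
```

`torch.allclose` is used instead of exact equality because the vectorized sum may accumulate in a different order than the per-row sum.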
How it works:

`X[:, None, :]` and `X_[:, None, :]` are 3D tensors of size `(batch_size, 1, dim)`, and `Y[None, ...]` is of size `(1, batch_size, dim)`.

When multiplying `X_[:, None, :] * Y[None, ...]`, PyTorch broadcasts each dimension of size 1 to match the other operand, giving a result of size `(batch_size, batch_size, dim)` whose entry `[i, j]` is `X_[i] * Y[j]`. Subtracting that from `X[:, None, :]` broadcasts the same way, so entry `[i, j]` of the difference is `X[i] - X_[i] * Y[j]`.

Finally, you `sum()` only over the last dimension (`dim=2`) to get an output `M` of size `(batch_size, batch_size)`.
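To see the shapes described above concretely, here is a small sketch that builds each intermediate tensor and prints its shape. The sizes are illustrative assumptions, chosen so that `batch_size` and `dim` differ and are easy to tell apart:

```python
import torch

batch_size, dim = 5, 3   # illustrative sizes (assumed)
X = torch.randn(batch_size, dim)
X_ = torch.randn(batch_size, dim)
Y = torch.randn(batch_size, dim)

a = X[:, None, :]    # (batch_size, 1, dim): indexes rows i
b = X_[:, None, :]   # (batch_size, 1, dim): indexes rows i
c = Y[None, ...]     # (1, batch_size, dim): indexes rows j

prod = b * c                 # broadcast to (batch_size, batch_size, dim)
diff = a - prod              # still (batch_size, batch_size, dim)
M = (diff ** 2).sum(dim=2)   # reduce over dim -> (batch_size, batch_size)

for name, t in [("X[:, None, :]", a), ("Y[None, ...]", c), ("product", prod), ("M", M)]:
    print(name, tuple(t.shape))
# X[:, None, :] (5, 1, 3)
# Y[None, ...] (1, 5, 3)
# product (5, 5, 3)
# M (5, 5)
```

Note that the `(batch_size, batch_size, dim)` intermediate is fully materialized, so memory grows with `batch_size**2 * dim`.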
The whole trick is broadcasting, which evaluates every `(i, j)` pair in one vectorized operation instead of a Python double loop.
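Since the question asks about matrix multiplication specifically, here is an alternative that is not part of the answer above: expanding the square (writing $X'$ for `X_`) gives $M_{ij} = \sum_k X_{ik}^2 - 2\sum_k X_{ik} X'_{ik} Y_{jk} + \sum_k (X'_{ik})^2 Y_{jk}^2$, and the last two sums are ordinary matrix products. The helper names `pairwise_matmul` and `pairwise_broadcast` are hypothetical, and this is a sketch checked only against the broadcast version:

```python
import torch

def pairwise_matmul(X, X_, Y):
    """M[i, j] = ((X[i] - X_[i] * Y[j]) ** 2).sum(), computed with matrix products.

    Expands the square so no (batch_size, batch_size, dim) tensor is built.
    """
    sq_x = (X ** 2).sum(dim=1, keepdim=True)  # (batch_size, 1): sum_k X[i,k]^2
    cross = (X * X_) @ Y.T                    # (batch_size, batch_size): sum_k X[i,k] X_[i,k] Y[j,k]
    sq_xy = (X_ ** 2) @ (Y ** 2).T            # (batch_size, batch_size): sum_k X_[i,k]^2 Y[j,k]^2
    return sq_x - 2 * cross + sq_xy           # sq_x broadcasts across columns j

def pairwise_broadcast(X, X_, Y):
    # The broadcasting answer, for comparison
    return ((X[:, None, :] - X_[:, None, :] * Y[None, ...]) ** 2).sum(dim=2)

torch.manual_seed(0)
batch_size, dim = 6, 4  # illustrative sizes (assumed)
X = torch.randn(batch_size, dim, dtype=torch.float64)
X_ = torch.randn(batch_size, dim, dtype=torch.float64)
Y = torch.randn(batch_size, dim, dtype=torch.float64)

M_mm = pairwise_matmul(X, X_, Y)
M_bc = pairwise_broadcast(X, X_, Y)
print(torch.allclose(M_mm, M_bc))  # True
```

The trade-off: this version needs only `O(batch_size**2 + batch_size * dim)` extra memory, but subtracting large, nearly equal terms can lose precision (especially in float32) and may produce tiny negative entries; `M_mm.clamp_min(0)` is one way to guard against that if it matters.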