This is a question about the internal workings of torch.einsum on the GPU. I know how to use einsum. Does it perform all possible matrix multiplications and then pick out the relevant ones, or does it perform only the required computation?
For example, consider two tensors a and b, both of shape (N, P), where I want the dot product of each pair of corresponding rows a_i and b_i, each of shape (1, P).
Using einsum, the code is:
torch.einsum('ij,ij->i',a,b)
Without einsum, another way to obtain the same output is:
torch.diag(a @ b.t())
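As a quick sanity check, the formulations below should agree: the einsum, the diagonal of the full product, and an explicit elementwise multiply followed by a sum over the second axis. The third form is added here only for comparison, as one direct way to spell 'ij,ij->i'; it is not a claim about what einsum does internally. Small sizes are assumed so the (N, N) intermediate stays tiny.

```python
import torch

torch.manual_seed(0)
N, P = 4, 3  # small illustrative sizes (assumed)
a = torch.rand(N, P)
b = torch.rand(N, P)

# 'ij,ij->i': multiply matching entries, sum over j, keep one value per row i
via_einsum = torch.einsum('ij,ij->i', a, b)

# Full (N, N) matrix of every row-by-row dot product, then keep only the diagonal
full = a @ b.t()
via_diag = torch.diag(full)

# Explicit row-wise dot product: N * P multiplications in total
via_sum = (a * b).sum(dim=1)

print(full.shape, via_einsum.shape)  # torch.Size([4, 4]) torch.Size([4])
print(torch.allclose(via_einsum, via_diag), torch.allclose(via_einsum, via_sum))
```

The shapes make the difference visible: the matmul route builds N * N values and discards all but N of them.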
Now, the second version should perform significantly more computation than the first (e.g. for N = 2000 it performs 2000 times more). However, when I time the two operations, they take roughly the same amount of time, which raises the question: does einsum compute all combinations (like the second version) and then pick out the relevant values?
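The "N times more computation" estimate can be checked by counting multiply-adds: the row-wise dot products need N * P of them, while a @ b.t() computes all N * N row pairs, i.e. N * N * P. The helper names below are hypothetical and only count arithmetic; actual run time also depends on memory traffic and on how well each kernel uses the hardware.

```python
def rowwise_dot_macs(n: int, p: int) -> int:
    """Multiply-adds for einsum('ij,ij->i'): one length-p dot product per row."""
    return n * p


def full_matmul_macs(n: int, p: int) -> int:
    """Multiply-adds for a @ b.t(): a length-p dot product for every (row, row) pair."""
    return n * n * p


n, p = 2000, 256  # N from the example above; P = 256 as in the sample code
ratio = full_matmul_macs(n, p) // rowwise_dot_macs(n, p)
print(rowwise_dot_macs(n, p), full_matmul_macs(n, p), ratio)  # 512000 1024000000 2000
```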
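Sample code to test:
import time
import torch

for i in range(100):
    # two random (N, P) = (50000, 256) matrices on the GPU
    a = torch.rand(50000, 256).cuda()
    b = torch.rand(50000, 256).cuda()

    t1 = time.time()
    # full (N, N) product, then keep only the diagonal
    val = torch.diag(a @ b.t())
    t2 = time.time()
    # row-wise dot products via einsum
    val2 = torch.einsum('ij,ij->i', a, b)
    t3 = time.time()

    # note: CUDA kernels run asynchronously, so without torch.cuda.synchronize()
    # these wall-clock deltas may only reflect kernel launch time
    print(t2 - t1, t3 - t2, torch.allclose(val, val2))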
It probably has to do with the fact that the GPU can parallelize the computation of a @ b.t(). This means that the GPU doesn't have to wait for one row-column dot product to finish before computing the next one.
If you check on the CPU, you will see that torch.diag(a @ b.t()) is significantly slower than torch.einsum('ij,ij->i', a, b) for large a and b.
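One caveat when comparing timings like those in the question: PyTorch launches CUDA kernels asynchronously, so a time.time() call right after a GPU operation can return before the kernel has finished. Below is a minimal sketch of a fairer comparison, assuming it is acceptable to call torch.cuda.synchronize() before each clock read. It falls back to the CPU when no GPU is available. The sizes are deliberately smaller than in the question, because a @ b.t() with N = 50000 materializes a 50000 x 50000 float32 matrix (about 10 GB). The helper name time_op is hypothetical.

```python
import time

import torch


def time_op(fn, device: torch.device, repeats: int = 10) -> float:
    """Average wall-clock seconds per call of fn(), waiting for GPU work to finish."""
    fn()  # warm-up call so one-time setup costs are not measured
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if device.type == "cuda":
        # block until every queued kernel has completed before reading the clock
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N, P = 2000, 256  # much smaller than the question's 50000 so the (N, N) product stays small
a = torch.rand(N, P, device=device)
b = torch.rand(N, P, device=device)

t_diag = time_op(lambda: torch.diag(a @ b.t()), device)
t_einsum = time_op(lambda: torch.einsum('ij,ij->i', a, b), device)

print(f"{device}: diag(a @ b.t()) {t_diag:.6f}s, einsum {t_einsum:.6f}s")
```

The sketch makes no prediction about the size of the gap on any particular hardware. torch.utils.benchmark.Timer is a built-in alternative that also takes care of CUDA synchronization.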