I have come across some code that uses torch.einsum to compute a tensor multiplication. I can follow how it works for lower-order tensors, but not for the 4D case below:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
I need help regarding: Is torch.einsum actually beneficial in this scenario?

(Skip to the tl;dr section if you just want the breakdown of the steps involved in an einsum.)
I'll try to explain how einsum works step by step for this example, but instead of torch.einsum I'll be using numpy.einsum (documentation), which does exactly the same thing; I'm just, in general, more comfortable with it. Nonetheless, the same steps apply to torch as well.
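As a quick sanity check (a minimal sketch, reusing the shapes from the torch snippet above), both give the same result on the same data:
import numpy as np
import torch

a_t = torch.rand((3, 5, 2, 10))
b_t = torch.rand((3, 4, 2, 10))

# same subscripts, same data -> same result from torch and numpy
np.allclose(torch.einsum('nxhd,nyhd->nhxy', a_t, b_t).numpy(),
            np.einsum('nxhd,nyhd->nhxy', a_t.numpy(), b_t.numpy()))
#True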
Let's rewrite the above code in NumPy -
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
Einsum is composed of 3 steps: multiply, sum and transpose.
Let's look at our dimensions. We have a (3, 5, 2, 10) and a (3, 4, 2, 10) that we need to bring to (3, 2, 5, 4) based on 'nxhd,nyhd->nhxy'. Let's not worry about the order of the n, x, y, h, d axes for now, and just focus on whether we want to keep each of them or remove (reduce) it. Writing them down as a table, let's see how we can arrange our dimensions -
## Multiply ##
        n   x   y   h   d
      ---------------------
a  ->   3   5       2   10
b  ->   3       4   2   10
c1 ->   3   5   4   2   10
To get a broadcasting multiplication between the x and y axes resulting in (x, y), we have to add a new axis to each tensor at the right place and then multiply.
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
Next, we want to reduce the last axis d (of size 10). This will give us the dimensions (n, x, y, h).
## Reduce ##
        n   x   y   h   d
      ---------------------
c1 ->   3   5   4   2   10
c2 ->   3   5   4   2
This is straightforward. Let's just do np.sum over axis=-1.
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
The last step is rearranging the axes using a transpose. We can use np.transpose for this. transpose(0, 3, 1, 2) keeps the 0th axis in place, moves the 3rd axis to position 1, and pushes the 1st and 2nd axes back. So (n, x, y, h) becomes (n, h, x, y).
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
Let's do a final check and see if c3 is the same as the c that was generated by np.einsum -
np.allclose(c,c3)
#True
Thus, we have implemented 'nxhd,nyhd->nhxy' as -
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
The advantage of np.einsum over the multiple separate steps is that you can choose the "path" it takes to do the computation and perform multiple operations with the same function. This can be done via the optimize parameter, which will optimize the contraction order of the einsum expression.
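For example (a minimal sketch, reusing a, b and c from the NumPy snippets above):
c_opt = np.einsum('nxhd,nyhd->nhxy', a, b, optimize=True)
np.allclose(c, c_opt)
#True

# inspect the contraction order einsum would use, without running it
path, info = np.einsum_path('nxhd,nyhd->nhxy', a, b, optimize='optimal')
print(path)
#['einsum_path', (0, 1)]  <-- with only two operands there is a single contraction
print(info)  # human-readable summary of the scaling and FLOP estimate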
A non-exhaustive list of operations that can be computed with einsum is shown below:
numpy.trace
numpy.diag
numpy.sum
numpy.transpose
numpy.matmul
numpy.dot
numpy.inner
numpy.outer
numpy.multiply
numpy.tensordot
numpy.einsum_path
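A few of these written as einsum calls, just as an illustration (a minimal sketch on small toy arrays):
x = np.random.random((4, 4))
y = np.random.random((4, 4))
v = np.random.random(4)

np.allclose(np.einsum('ii', x), np.trace(x))                 #True, trace
np.allclose(np.einsum('ii->i', x), np.diag(x))               #True, main diagonal
np.allclose(np.einsum('ij->ji', x), np.transpose(x))         #True, transpose
np.allclose(np.einsum('ij,jk->ik', x, y), np.matmul(x, y))   #True, matrix multiplication
np.allclose(np.einsum('i,i->', v, v), np.inner(v, v))        #True, inner product
np.allclose(np.einsum('i,j->ij', v, v), np.outer(v, v))      #True, outer product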
Timing the single einsum call against the step-by-step version -
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
This shows that np.einsum performs the operation faster than the individual steps.
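And, coming back to the torch snippet from the question, the same three steps carry over directly (a minimal sketch using unsqueeze and permute in place of None-indexing and transpose):
import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c1 = a.unsqueeze(2) * b.unsqueeze(1)   # broadcasted multiply -> (n, x, y, h, d)
c2 = c1.sum(dim=-1)                    # reduce d             -> (n, x, y, h)
c3 = c2.permute(0, 3, 1, 2)            # transpose            -> (n, h, x, y)

torch.allclose(c3, torch.einsum('nxhd,nyhd->nhxy', a, b))
#True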