
# Understanding PyTorch einsum

I'm familiar with how `einsum` works in NumPy. PyTorch offers similar functionality through `torch.einsum()`. What are the similarities and differences between the two, in terms of either functionality or performance? The PyTorch documentation is rather sparse and doesn't offer any insight into this.

Solution

Since the description of `einsum` in the PyTorch documentation is sparse, I decided to write this post to document, compare, and contrast how `torch.einsum()` behaves relative to `numpy.einsum()`.

Differences:

• NumPy allows both lowercase and uppercase letters `[a-zA-Z]` in the "subscript string", whereas PyTorch, at the time of writing, allowed only lowercase letters `[a-z]` (more recent PyTorch releases also accept uppercase letters).

• NumPy accepts nd-arrays, plain Python lists or tuples, nested lists or tuples (in any combination), and even PyTorch tensors as operands (i.e. inputs), because the operands only have to be array_like rather than strictly NumPy nd-arrays. PyTorch, by contrast, expects the operands to be PyTorch tensors. It throws a `TypeError` if you pass plain Python lists or tuples (or combinations of them) or NumPy nd-arrays.

• NumPy supports a number of keyword arguments (e.g. `optimize`) in addition to the operands, while PyTorch doesn't offer such flexibility yet.
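The operand and keyword-argument differences listed above can be tried out with a small sketch like the following. The inputs `a` and `b` are made-up illustration data, and the failing call is wrapped in a broad `except` so that the exception type is recorded rather than assumed.

```python
import numpy as np
import torch

a = [[1, 2], [3, 4]]   # plain nested lists (illustrative data)
b = ((5, 6), (7, 8))   # nested tuples (illustrative data)

# NumPy converts array_like operands on the fly and accepts keyword arguments.
np_result = np.einsum("ij, jk -> ik", a, b, optimize=True)

# PyTorch needs tensors, so convert explicitly before calling einsum.
torch_result = torch.einsum("ij, jk -> ik", torch.tensor(a), torch.tensor(b))

# Passing the raw lists straight to torch.einsum is expected to fail;
# record the exception type instead of hard-coding it.
try:
    torch.einsum("ij, jk -> ik", a, b)
    raw_list_error = None
except Exception as exc:
    raw_list_error = type(exc).__name__

print(np_result)
print(torch_result)
print("torch.einsum on raw lists raised:", raw_list_error)
```

Converting with `torch.tensor(...)` (or `torch.from_numpy(...)` for existing arrays) is the usual way to bridge the two libraries.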

Here are the implementations of some examples both in PyTorch and NumPy:

``````
# input tensors to work with

In [16]: vec
Out[16]: tensor([0, 1, 2, 3])

In [17]: aten
Out[17]:
tensor([[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]])

In [18]: bten
Out[18]:
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3],
        [4, 4, 4, 4]])
``````
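The post prints these tensors without showing how they were created. One possible way to rebuild them is sketched below; the construction itself is an assumption, and only the resulting values are taken from the output above. The names `arr1` and `arr2` mirror the placeholders used in the NumPy one-liners.

```python
import torch

# Vector 0..3, as printed for `vec`.
vec = torch.arange(4)

# Row index (1..4) times 10 plus column index (1..4) gives 11..44.
rows = torch.arange(1, 5).reshape(4, 1)
cols = torch.arange(1, 5).reshape(1, 4)
aten = rows * 10 + cols

# Every entry of row r equals r, as printed for `bten`.
bten = rows.repeat(1, 4)

# NumPy counterparts for the np.einsum one-liners.
arr1, arr2 = aten.numpy(), bten.numpy()

print(vec, aten, bten, sep="\n")
```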

1) Matrix multiplication
PyTorch: `torch.matmul(aten, bten)` ; `aten.mm(bten)`
NumPy: `np.einsum("ij, jk -> ik", arr1, arr2)`

``````
In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]:
tensor([[130, 130, 130, 130],
        [230, 230, 230, 230],
        [330, 330, 330, 330],
        [430, 430, 430, 430]])
``````
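As a quick cross-check, the sketch below compares the einsum result with the built-in PyTorch calls and with `np.einsum`. The tensors are rebuilt inline so the snippet runs on its own; how they are constructed is an assumption, not something shown in the post.

```python
import numpy as np
import torch

# Same values as the `aten` and `bten` tensors printed earlier.
aten = torch.arange(1, 5).reshape(4, 1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

via_einsum = torch.einsum("ij, jk -> ik", aten, bten)
via_matmul = torch.matmul(aten, bten)
via_mm = aten.mm(bten)
via_numpy = np.einsum("ij, jk -> ik", aten.numpy(), bten.numpy())

# All four routes should agree exactly, since the inputs are integers.
same_in_torch = torch.equal(via_einsum, via_matmul) and torch.equal(via_einsum, via_mm)
same_as_numpy = np.array_equal(via_einsum.numpy(), via_numpy)
print(same_in_torch, same_as_numpy)
```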

2) Extract elements along the main-diagonal
PyTorch: `torch.diag(aten)`
NumPy: `np.einsum("ii -> i", arr)`

``````
In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])
``````

3) Hadamard product (i.e. element-wise product of two tensors)
PyTorch: `aten * bten`
NumPy: `np.einsum("ij, ij -> ij", arr1, arr2)`

``````
In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]:
tensor([[ 11,  12,  13,  14],
        [ 42,  44,  46,  48],
        [ 93,  96,  99, 102],
        [164, 168, 172, 176]])
``````

4) Element-wise squaring
PyTorch: `aten ** 2`
NumPy: `np.einsum("ij, ij -> ij", arr, arr)`

``````
In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]:
tensor([[ 121,  144,  169,  196],
        [ 441,  484,  529,  576],
        [ 961, 1024, 1089, 1156],
        [1681, 1764, 1849, 1936]])
``````

General: The element-wise `n`th power can be implemented by repeating the subscript string and the tensor `n` times. For example, the element-wise 4th power of a tensor can be computed using:

``````
# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]:
tensor([[  14641,   20736,   28561,   38416],
        [ 194481,  234256,  279841,  331776],
        [ 923521, 1048576, 1185921, 1336336],
        [2825761, 3111696, 3418801, 3748096]])
``````
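The "repeat the subscript and the tensor `n` times" recipe can be wrapped in a small helper. `elementwise_power` below is a hypothetical name introduced only for illustration, and it is restricted to 2-D tensors to match the `ij` subscripts used in the post.

```python
import torch

def elementwise_power(t: torch.Tensor, n: int) -> torch.Tensor:
    """Element-wise n-th power of a 2-D tensor via einsum (hypothetical helper)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    # "ij" repeated n times on the left, e.g. n=3 gives "ij,ij,ij->ij".
    equation = ",".join(["ij"] * n) + "->ij"
    return torch.einsum(equation, *([t] * n))

# Same values as the `aten` tensor printed earlier.
aten = torch.arange(1, 5).reshape(4, 1) * 10 + torch.arange(1, 5)

fourth = elementwise_power(aten, 4)
matches_pow = torch.equal(fourth, aten ** 4)
print(fourth)
print(matches_pow)
```

In practice `aten ** n` is simpler; the einsum form mainly shows how operands that share every index are multiplied element-wise.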

5) Trace (i.e. sum of main-diagonal elements)
PyTorch: `torch.trace(aten)`
NumPy einsum: `np.einsum("ii -> ", arr)`

``````
In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)
``````

6) Matrix transpose
PyTorch: `torch.transpose(aten, 1, 0)`
NumPy einsum: `np.einsum("ij -> ji", arr)`

``````
In [58]: torch.einsum('ij -> ji', aten)
Out[58]:
tensor([[11, 21, 31, 41],
        [12, 22, 32, 42],
        [13, 23, 33, 43],
        [14, 24, 34, 44]])
``````

7) Outer Product (of vectors)
PyTorch: `torch.ger(vec, vec)` (newer releases name this `torch.outer`)
NumPy einsum: `np.einsum("i, j -> ij", vec, vec)`

``````
In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]:
tensor([[0, 0, 0, 0],
        [0, 1, 2, 3],
        [0, 2, 4, 6],
        [0, 3, 6, 9]])
``````
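A short sketch comparing the einsum outer product with `torch.outer` and with `np.einsum`. The second pair of vectors, `u` and `v`, is made up here to show that the two operands need not have the same length.

```python
import numpy as np
import torch

vec = torch.arange(4)

outer_einsum = torch.einsum("i, j -> ij", vec, vec)
outer_builtin = torch.outer(vec, vec)
outer_numpy = np.einsum("i, j -> ij", vec.numpy(), vec.numpy())

# Vectors of different lengths give a rectangular result (illustrative data).
u, v = torch.arange(3), torch.arange(4)
rect = torch.einsum("i, j -> ij", u, v)

print(outer_einsum)
print(tuple(rect.shape))
```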

8) Inner Product (of vectors)
PyTorch: `torch.dot(vec1, vec2)`
NumPy einsum: `np.einsum("i, i -> ", vec1, vec2)`

``````
In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)
``````

9) Sum along axis 0
PyTorch: `torch.sum(aten, 0)`
NumPy einsum: `np.einsum("ij -> j", arr)`

``````
In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])
``````

10) Sum along axis 1
PyTorch: `torch.sum(aten, 1)`
NumPy einsum: `np.einsum("ij -> i", arr)`

``````
In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50,  90, 130, 170])
``````

11) Batch Matrix Multiplication
PyTorch: `torch.bmm(batch_tensor_1, batch_tensor_2)`
NumPy: `np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)`

``````
# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)
Out[15]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape
Out[16]: torch.Size([2, 4, 4])

# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]:
tensor([[[  20,   23,   26,   29],
         [  56,   68,   80,   92],
         [  92,  113,  134,  155],
         [ 128,  158,  188,  218]],

        [[ 632,  671,  710,  749],
         [ 776,  824,  872,  920],
         [ 920,  977, 1034, 1091],
         [1064, 1130, 1196, 1262]]])

# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape
Out[18]: torch.Size([2, 4, 4])
``````

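12) Sum along axis 2
PyTorch: `torch.sum(batch_ten, 2)`
NumPy einsum: `np.einsum("ijk -> ij", arr3D)`

(`batch_ten` is not defined in this post; judging by the output, it appears to be `aten` and `bten` stacked along a new leading axis, i.e. a tensor of shape `(2, 4, 4)`.)

``````
In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]:
tensor([[ 50,  90, 130, 170],
        [  4,   8,  12,  16]])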
``````

13) Sum all the elements in an nD tensor
PyTorch: `torch.sum(batch_ten)`
NumPy einsum: `np.einsum("ijk -> ", arr3D)`

``````
In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)
``````
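Examples 12 and 13 can be reproduced with the sketch below. It rests on one assumption: that `batch_ten` is `aten` and `bten` stacked along a new leading axis. That guess is inferred from the printed results and is not stated in the post.

```python
import torch

# Same values as the `aten` and `bten` tensors printed earlier.
aten = torch.arange(1, 5).reshape(4, 1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

# Assumption: batch_ten stacks aten and bten, giving shape (2, 4, 4).
batch_ten = torch.stack((aten, bten))

sum_axis2 = torch.einsum("ijk -> ij", batch_ten)  # same as torch.sum(batch_ten, 2)
total = torch.einsum("ijk -> ", batch_ten)        # same as torch.sum(batch_ten)

print(sum_axis2)
print(total)
```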

14) Sum over multiple axes (i.e. marginalization)
PyTorch: `torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))`
NumPy: `np.einsum("ijklmnop -> n", nDarr)`

``````
# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])

# keep dimension 5 (i.e. "n" here) and sum over all the other axes
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([  98.6921, -206.0575])

# same reduction with torch.sum: sum over every axis except axis 5
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))

In [115]: torch.allclose(tsum, esum)
Out[115]: True
``````
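Writing out the subscript string and the `dim` tuple by hand gets tedious for high-dimensional tensors. The sketch below builds both from the index of the axis to keep. `sum_all_but` is a hypothetical helper name, and `float64` with a fixed seed is used only to make the comparison reproducible.

```python
import string
import torch

def sum_all_but(t: torch.Tensor, keep: int) -> torch.Tensor:
    """Sum over every axis of `t` except `keep`, using einsum (hypothetical helper)."""
    letters = string.ascii_lowercase[8:8 + t.ndim]  # subscripts start at 'i'
    if len(letters) < t.ndim:
        raise ValueError("too many dimensions for the letters i..z")
    equation = f"{letters}->{letters[keep]}"        # e.g. "ijklmnop->n" for keep=5
    return torch.einsum(equation, t)

torch.manual_seed(0)
nDten = torch.randn(3, 5, 4, 6, 8, 2, 7, 9, dtype=torch.float64)

esum = sum_all_but(nDten, keep=5)

# The same reduction expressed with torch.sum over all other axes.
other_dims = tuple(d for d in range(nDten.ndim) if d != 5)
tsum = torch.sum(nDten, dim=other_dims)

agree = torch.allclose(esum, tsum)
print(tuple(esum.shape), agree)
```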

15) Double dot product / Frobenius inner product (same as `torch.sum` of the Hadamard product, cf. 3)
PyTorch: `torch.sum(aten * bten)`
NumPy: `np.einsum("ij, ij -> ", arr1, arr2)`

``````
In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
``````
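The Frobenius inner product can be computed in several equivalent ways. The sketch below compares the einsum form with the sum of the Hadamard product and with the identity trace(AᵀB); the trace identity is standard linear algebra and is not taken from the post.

```python
import torch

# Same values as the `aten` and `bten` tensors printed earlier.
aten = torch.arange(1, 5).reshape(4, 1) * 10 + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

frob_einsum = torch.einsum("ij, ij -> ", aten, bten)
frob_hadamard = torch.sum(aten * bten)    # sum of the element-wise product
frob_trace = torch.trace(aten.T @ bten)   # trace(A^T B) equals sum_ij A_ij * B_ij

print(frob_einsum, frob_hadamard, frob_trace)
```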