
Understanding PyTorch einsum


I'm familiar with how einsum works in NumPy. PyTorch offers similar functionality through torch.einsum(). What are the similarities and differences, in terms of either functionality or performance? The information in the PyTorch documentation is rather scanty and doesn't provide any insight into this.


Solution

  • Since the description of einsum is skimpy in the torch documentation, I decided to write this post to document, compare, and contrast how torch.einsum() behaves compared to numpy.einsum().

    Differences:

    • NumPy allows both lowercase and uppercase letters [a-zA-Z] in the "subscript string", whereas older PyTorch releases allowed only lowercase letters [a-z]. Recent PyTorch releases accept [a-zA-Z] as well.

    • NumPy accepts nd-arrays, plain Python lists (or tuples), lists of lists (or tuples of tuples, lists of tuples, tuples of lists), or even PyTorch tensors (on the CPU and not requiring grad) as operands (i.e. inputs), because the operands only have to be array_like, not strictly NumPy nd-arrays. PyTorch, on the contrary, expects the operands strictly to be PyTorch tensors and throws a TypeError if you pass plain Python lists/tuples (or their combinations) or NumPy nd-arrays.

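    • NumPy's einsum accepts several keyword arguments in addition to the operands (e.g. optimize, out, dtype), while torch.einsum() takes only the equation and the operands. In recent PyTorch releases, contraction-order optimization (when available) is controlled globally through torch.backends.opt_einsum rather than per call.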
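    A minimal sketch of the operand-type and keyword differences, assuming recent NumPy and PyTorch installs (the names a, b, and raised are illustrative only): NumPy converts nested lists itself, torch.einsum() rejects them with a TypeError, and wrapping them with torch.as_tensor() makes the call work.

```python
import numpy as np
import torch

a = [[1, 2], [3, 4]]  # plain nested Python lists (illustrative data)
b = [[5, 6], [7, 8]]

# NumPy accepts any array_like operand, including nested lists
np_result = np.einsum("ij, jk -> ik", a, b)

# PyTorch rejects non-tensor operands
try:
    torch.einsum("ij, jk -> ik", a, b)
    raised = False
except TypeError:
    raised = True

# Converting the operands to tensors first makes the PyTorch call work
torch_result = torch.einsum("ij, jk -> ik", torch.as_tensor(a), torch.as_tensor(b))

# NumPy-only keyword argument: let einsum search for a cheaper contraction order
np_optimized = np.einsum("ij, jk -> ik", a, b, optimize=True)
```

    With only two operands optimize=True cannot change much; it pays off for chains of three or more operands.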

    Here are the implementations of some examples both in PyTorch and NumPy:

    # input tensors to work with
    
    In [16]: vec
    Out[16]: tensor([0, 1, 2, 3])
    
    In [17]: aten
    Out[17]: 
    tensor([[11, 12, 13, 14],
            [21, 22, 23, 24],
            [31, 32, 33, 34],
            [41, 42, 43, 44]])
    
    In [18]: bten
    Out[18]: 
    tensor([[1, 1, 1, 1],
            [2, 2, 2, 2],
            [3, 3, 3, 3],
            [4, 4, 4, 4]])
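    The session does not show how vec, aten, and bten were created. One way to build tensors with the same values (the construction itself is an assumption; only the values come from the post), together with the NumPy arrays that the np.einsum lines call arr1 and arr2:

```python
import numpy as np
import torch

vec = torch.arange(4)  # tensor([0, 1, 2, 3])

# Row i (1-based) holds 10*i + [1, 2, 3, 4], i.e. 11..14, 21..24, ...
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)

# Row i (1-based) holds the value i repeated four times
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

# NumPy counterparts for the np.einsum examples
arr1, arr2 = aten.numpy(), bten.numpy()
```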
    

    1) Matrix multiplication
    PyTorch: torch.matmul(aten, bten) ; aten.mm(bten)
    NumPy : np.einsum("ij, jk -> ik", arr1, arr2)

    In [19]: torch.einsum('ij, jk -> ik', aten, bten)
    Out[19]: 
    tensor([[130, 130, 130, 130],
            [230, 230, 230, 230],
            [330, 330, 330, 330],
            [430, 430, 430, 430]])
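    A quick way to confirm that the same subscript string means the same thing in both libraries is to compare einsum against the dedicated routines. This sketch rebuilds aten and bten (construction assumed, values as in the post):

```python
import numpy as np
import torch

# Rebuild the post's inputs (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)
arr1, arr2 = aten.numpy(), bten.numpy()

# "j" appears in both inputs but not in the output, so it is summed over
t_mm = torch.einsum("ij, jk -> ik", aten, bten)
n_mm = np.einsum("ij, jk -> ik", arr1, arr2)

same_as_matmul = torch.equal(t_mm, torch.matmul(aten, bten)) and torch.equal(t_mm, aten.mm(bten))
same_across_libs = np.array_equal(n_mm, t_mm.numpy())
```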
    

    2) Extract elements along the main-diagonal
    PyTorch: torch.diag(aten)
    NumPy : np.einsum("ii -> i", arr)

    In [28]: torch.einsum('ii -> i', aten)
    Out[28]: tensor([11, 22, 33, 44])
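    Repeating an index within one operand selects the diagonal, and keeping it in the output prevents it from being summed. A small check against torch.diag, torch.diagonal (the more general routine, which works on any pair of dimensions), and np.diag, with aten rebuilt under an assumed construction:

```python
import numpy as np
import torch

# Rebuild aten (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)

# "ii" walks the diagonal; keeping "i" in the output returns it as a vector
t_diag = torch.einsum("ii -> i", aten)
n_diag = np.einsum("ii -> i", aten.numpy())

matches_torch = torch.equal(t_diag, torch.diag(aten)) and torch.equal(t_diag, torch.diagonal(aten))
matches_numpy = np.array_equal(n_diag, np.diag(aten.numpy()))
```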
    

    3) Hadamard product (i.e. element-wise product of two tensors)
    PyTorch: aten * bten
    NumPy : np.einsum("ij, ij -> ij", arr1, arr2)

    In [34]: torch.einsum('ij, ij -> ij', aten, bten)
    Out[34]: 
    tensor([[ 11,  12,  13,  14],
            [ 42,  44,  46,  48],
            [ 93,  96,  99, 102],
            [164, 168, 172, 176]])
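    When every input index also appears in the output, nothing is summed and einsum reduces to an element-wise product. A sketch comparing it with the * operator in both libraries (aten and bten rebuilt under an assumed construction):

```python
import numpy as np
import torch

# Rebuild the post's inputs (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

# Both "i" and "j" survive in the output, so this is a pure element-wise product
had = torch.einsum("ij, ij -> ij", aten, bten)
n_had = np.einsum("ij, ij -> ij", aten.numpy(), bten.numpy())

matches_torch = torch.equal(had, aten * bten) and torch.equal(had, torch.mul(aten, bten))
matches_numpy = np.array_equal(n_had, aten.numpy() * bten.numpy())
```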
    

    4) Element-wise squaring
    PyTorch: aten ** 2
    NumPy : np.einsum("ij, ij -> ij", arr, arr)

    In [37]: torch.einsum('ij, ij -> ij', aten, aten)
    Out[37]: 
    tensor([[ 121,  144,  169,  196],
            [ 441,  484,  529,  576],
            [ 961, 1024, 1089, 1156],
            [1681, 1764, 1849, 1936]])
    

    General: element-wise nth power can be implemented by repeating the operand subscripts and the tensor n times. For example, the element-wise 4th power of a tensor can be computed with:

    # NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
    In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
    Out[38]: 
    tensor([[  14641,   20736,   28561,   38416],
            [ 194481,  234256,  279841,  331776],
            [ 923521, 1048576, 1185921, 1336336],
            [2825761, 3111696, 3418801, 3748096]])
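    The "repeat n times" rule can be built programmatically. The helper elementwise_power below is hypothetical (not part of either library), a sketch that assembles the equation string and passes the same tensor n times:

```python
import torch


def elementwise_power(t, n):
    """Hypothetical helper: n-th element-wise power of a 2-D tensor via einsum.

    Builds "ij, ij, ..., ij -> ij" with n copies of the operand subscripts
    and passes the same tensor n times.
    """
    equation = ", ".join(["ij"] * n) + " -> ij"
    return torch.einsum(equation, *([t] * n))


# Rebuild aten (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)

squared = elementwise_power(aten, 2)  # equation "ij, ij -> ij"
fourth = elementwise_power(aten, 4)   # equation "ij, ij, ij, ij -> ij"
```

    In practice aten ** n (or torch.pow) is clearer and avoids the n - 1 intermediate products; the einsum form mainly illustrates how repeated operands combine.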
    

    5) Trace (i.e. sum of main-diagonal elements)
    PyTorch: torch.trace(aten)
    NumPy einsum: np.einsum("ii -> ", arr)

    In [44]: torch.einsum('ii -> ', aten)
    Out[44]: tensor(110)
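    The trace combines two rules: "ii" picks the diagonal and an empty output sums everything. A sketch that checks this decomposition and compares against torch.trace and np.trace (aten rebuilt under an assumed construction):

```python
import numpy as np
import torch

# Rebuild aten (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)

tr = torch.einsum("ii -> ", aten)

# Same thing in two steps: take the diagonal ("ii -> i"), then sum it ("i -> ")
two_step = torch.einsum("i -> ", torch.einsum("ii -> i", aten))

n_tr = np.einsum("ii -> ", aten.numpy())
```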
    

    6) Matrix transpose
    PyTorch: torch.transpose(aten, 1, 0)
    NumPy einsum: np.einsum("ij -> ji", arr)

    In [58]: torch.einsum('ij -> ji', aten)
    Out[58]: 
    tensor([[11, 21, 31, 41],
            [12, 22, 32, 42],
            [13, 23, 33, 43],
            [14, 24, 34, 44]])
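    Reordering the output subscripts permutes axes, and the idea extends beyond 2-D. A sketch comparing with torch.transpose, permute, and np.transpose (aten rebuilt under an assumed construction; x is an illustrative 3-D tensor):

```python
import numpy as np
import torch

# Rebuild aten (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)

t_T = torch.einsum("ij -> ji", aten)
same_as_transpose = torch.equal(t_T, torch.transpose(aten, 1, 0)) and torch.equal(t_T, aten.T)

# The same notation permutes any number of axes
x = torch.arange(24).reshape(2, 3, 4)
t_perm = torch.einsum("ijk -> kji", x)
n_perm = np.einsum("ijk -> kji", x.numpy())
```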
    

    7) Outer Product (of vectors)
    PyTorch: torch.ger(vec, vec) (deprecated in newer releases in favour of torch.outer(vec, vec))
    NumPy einsum: np.einsum("i, j -> ij", vec, vec)

    In [73]: torch.einsum('i, j -> ij', vec, vec)
    Out[73]: 
    tensor([[0, 0, 0, 0],
            [0, 1, 2, 3],
            [0, 2, 4, 6],
            [0, 3, 6, 9]])
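    Two distinct indices with no shared index and nothing dropped give every pairwise product. A sketch comparing with torch.outer and np.outer, assuming vec = torch.arange(4) as the values in the post suggest:

```python
import numpy as np
import torch

vec = torch.arange(4)  # assumed construction of tensor([0, 1, 2, 3])

# "i" and "j" are independent and both kept, so every pair (i, j) is multiplied
t_outer = torch.einsum("i, j -> ij", vec, vec)
same_as_outer = torch.equal(t_outer, torch.outer(vec, vec))

n_outer = np.einsum("i, j -> ij", vec.numpy(), vec.numpy())
```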
    

    8) Inner Product (of vectors)
    PyTorch: torch.dot(vec1, vec2)
    NumPy einsum: np.einsum("i, i -> ", vec1, vec2)

    In [76]: torch.einsum('i, i -> ', vec, vec)
    Out[76]: tensor(14)
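    Sharing "i" between the operands and dropping it from the output multiplies and then sums. A sketch showing that this equals torch.dot, np.dot, and also the trace of the outer product (assuming vec = torch.arange(4)):

```python
import numpy as np
import torch

vec = torch.arange(4)  # assumed construction of tensor([0, 1, 2, 3])

dot = torch.einsum("i, i -> ", vec, vec)

# Inner product = trace of the outer product ("i, j -> ij" followed by "ii -> ")
via_outer = torch.einsum("ii -> ", torch.einsum("i, j -> ij", vec, vec))

n_dot = np.einsum("i, i -> ", vec.numpy(), vec.numpy())
```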
    

    9) Sum along axis 0
    PyTorch: torch.sum(aten, 0)
    NumPy einsum: np.einsum("ij -> j", arr)

    In [85]: torch.einsum('ij -> j', aten)
    Out[85]: tensor([104, 108, 112, 116])
    

    10) Sum along axis 1
    PyTorch: torch.sum(aten, 1)
    NumPy einsum: np.einsum("ij -> i", arr)

    In [86]: torch.einsum('ij -> i', aten)
    Out[86]: tensor([ 50,  90, 130, 170])
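    In both axis sums the output subscripts name the axes that survive, and every omitted index is summed out. A sketch checking both against torch.sum and ndarray.sum (aten rebuilt under an assumed construction):

```python
import numpy as np
import torch

# Rebuild aten (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)

col_sums = torch.einsum("ij -> j", aten)  # "i" dropped: sum over axis 0
row_sums = torch.einsum("ij -> i", aten)  # "j" dropped: sum over axis 1

n_col = np.einsum("ij -> j", aten.numpy())
n_row = np.einsum("ij -> i", aten.numpy())
```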
    

    11) Batch Matrix Multiplication
    PyTorch: torch.bmm(batch_tensor_1, batch_tensor_2)
    NumPy : np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

    # input batch tensors to work with
    In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
    In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4) 
    
    In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)  
    Out[15]: 
    tensor([[[  20,   23,   26,   29],
             [  56,   68,   80,   92],
             [  92,  113,  134,  155],
             [ 128,  158,  188,  218]],
    
            [[ 632,  671,  710,  749],
             [ 776,  824,  872,  920],
             [ 920,  977, 1034, 1091],
             [1064, 1130, 1196, 1262]]])
    
    # sanity check with the shapes
    In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape 
    Out[16]: torch.Size([2, 4, 4])
    
    # batch matrix multiply using einsum
    In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
    Out[17]: 
    tensor([[[  20,   23,   26,   29],
             [  56,   68,   80,   92],
             [  92,  113,  134,  155],
             [ 128,  158,  188,  218]],
    
            [[ 632,  671,  710,  749],
             [ 776,  824,  872,  920],
             [ 920,  977, 1034, 1091],
             [1064, 1130, 1196, 1262]]])
    
    # sanity check with the shapes
    In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape
    Out[18]: torch.Size([2, 4, 4])
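    Keeping "b" in every operand and in the output turns it into a batch index. Both libraries also accept "..." for any number of leading batch dimensions; this sketch reuses the post's batch tensors and checks the variants against torch.bmm and np.matmul:

```python
import numpy as np
import torch

# Same inputs as in the session above
batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)

# "b" is kept everywhere, "j" is contracted
t_bmm = torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)

# "..." stands for any number of leading batch dimensions
t_ellipsis = torch.einsum("...ij, ...jk -> ...ik", batch_tensor_1, batch_tensor_2)

n_bmm = np.einsum("bij, bjk -> bik", batch_tensor_1.numpy(), batch_tensor_2.numpy())
```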
    

    12) Sum along axis 2
    PyTorch: torch.sum(batch_ten, 2)
    NumPy einsum: np.einsum("ijk -> ij", arr3D)

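    # batch_ten: assumed here to be torch.stack((aten, bten)), a 2 x 4 x 4 tensor, which reproduces the outputs below
    In [99]: torch.einsum("ijk -> ij", batch_ten)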
    Out[99]: 
    tensor([[ 50,  90, 130, 170],
            [  4,   8,  12,  16]])
    

    13) Sum all the elements in an nD tensor
    PyTorch: torch.sum(batch_ten)
    NumPy einsum: np.einsum("ijk -> ", arr3D)

    In [101]: torch.einsum("ijk -> ", batch_ten)
    Out[101]: tensor(480)
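    The session does not define batch_ten; stacking aten and bten is an assumption that reproduces both outputs shown (row sums 50, 90, 130, 170 and 4, 8, 12, 16; total 480). A sketch under that assumption:

```python
import numpy as np
import torch

# Rebuild the post's inputs (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

# Assumption: batch_ten is the two matrices stacked along a new axis 0
batch_ten = torch.stack((aten, bten))  # shape (2, 4, 4)

axis2 = torch.einsum("ijk -> ij", batch_ten)  # "k" dropped: sum over axis 2
total = torch.einsum("ijk -> ", batch_ten)    # every index dropped: grand total

n_total = np.einsum("ijk -> ", batch_ten.numpy())
```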
    

    14) Sum over multiple axes (i.e. marginalization)
    PyTorch: torch.sum(nDten, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
    NumPy: np.einsum("ijklmnop -> n", nDarr)

    # 8D tensor
    In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
    In [104]: nDten.shape
    Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])
    
    # marginalize onto dimension 5 (i.e. "n" here): keep it and sum over all the other axes
    In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
    In [112]: esum
    Out[112]: tensor([  98.6921, -206.0575])
    
    # same result with torch.sum over every axis except 5
    In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))
    
    In [115]: torch.allclose(tsum, esum)
    Out[115]: True
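    The session uses unseeded random data, so its printed values will differ from run to run. A reproducible sketch, assuming a fixed seed and float64 (so that round-off from the different summation orders stays negligible), that also checks the NumPy version:

```python
import numpy as np
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible
# float64 keeps round-off from the different summation orders negligible
nDten = torch.randn((3, 5, 4, 6, 8, 2, 7, 9), dtype=torch.float64)

esum = torch.einsum("ijklmnop -> n", nDten)         # keep axis 5 ("n")
tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))  # sum every other axis
nsum = np.einsum("ijklmnop -> n", nDten.numpy())
```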
    

    15) Double dot product / Frobenius inner product (same as torch.sum of the Hadamard product, cf. 3)
    PyTorch: torch.sum(aten * bten)
    NumPy : np.einsum("ij, ij -> ", arr1, arr2)

    In [120]: torch.einsum("ij, ij -> ", aten, bten)
    Out[120]: tensor(1300)
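    The Frobenius inner product can be written three equivalent ways: einsum with all indices contracted, the sum of the Hadamard product, or trace(A^T B). A sketch checking all three plus np.vdot, which flattens its inputs (aten and bten rebuilt under an assumed construction):

```python
import numpy as np
import torch

# Rebuild the post's inputs (construction assumed; values match the post)
aten = 10 * torch.arange(1, 5).reshape(4, 1) + torch.arange(1, 5)
bten = torch.arange(1, 5).reshape(4, 1).repeat(1, 4)

frob = torch.einsum("ij, ij -> ", aten, bten)

via_hadamard = torch.sum(aten * bten)    # sum of the element-wise product
via_trace = torch.trace(aten.T @ bten)   # sum_j (A^T B)_jj = sum_ij A_ij B_ij

n_frob = np.einsum("ij, ij -> ", aten.numpy(), bten.numpy())
```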