python · numpy · torch

Using einsum for row vector times matrix times transposed vector: x @ A @ x^T


I have m different row vectors x, each of shape (1, n), stacked vertically into an (m, n) matrix B, and a matrix A of shape (n, n).

I want to compute x A x^T for every vector x; the output should be (m, 1).

How can I write this as an einsum expression given B and A?

Here is a sample without einsum:

    import torch
    m = 30
    n = 4
    B = torch.randn(m, n)   # each row is one vector x of length n
    A = torch.randn(n, n)
    result = torch.zeros(m, 1)
    for i in range(m):
        x = B[i].unsqueeze(0)                               # shape (1, n)
        result[i] = torch.matmul(x, torch.matmul(A, x.T))   # x @ A @ x^T

Solution

  • Using einsum and tensordot:

    import torch
    
    A = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                      [0.0, 2.0, 0.0, 0.0],
                      [0.0, 0.0, 3.0, 0.0],
                      [0.0, 0.0, 0.0, 4.0]])
    
    B = torch.tensor([[1.0, 4.0, 7.0, 10.0],
                      [2.0, 5.0, 8.0, 11.0],
                      [3.0, 6.0, 9.0, 12.0]])
    
    # Using einsum: in 'bi,ij,bj -> b', b indexes the rows of B, while the
    # repeated indices i and j are summed over, so each output element is
    # sum_ij B[b,i] * A[i,j] * B[b,j]
    
    res_using_einsum = torch.einsum('bi,ij,bj -> b', B, A, B).unsqueeze(dim=-1)
    print(res_using_einsum)
    '''
    tensor([[580.],
            [730.],
            [900.]])
    '''
    
    # Using tensordot: contract B's columns with A's rows, i.e. BA = B @ A
    
    BA = torch.tensordot(B, A, dims=([1], [0]))
    '''
    BA :
    tensor([[ 1.,  8., 21., 40.],
            [ 2., 10., 24., 44.],
            [ 3., 12., 27., 48.]])
    '''
    res_using_tensordot = torch.tensordot(BA, B, dims=([1], [1]))  # (m, m) matrix BA @ B.T
    
    '''
    res_using_tensordot :
    
    tensor([[580., 650., 720.],
            [650., 730., 810.],
            [720., 810., 900.]])
    '''
    # Only the diagonal pairs each row with itself; off-diagonal entries are wasted work
    diagonal_result = torch.diagonal(res_using_tensordot, 0).unsqueeze(1)
    '''
    tensor([[580.],
            [730.],
            [900.]])
    
    '''
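As a side note (this variant is not part of the original answer), the same quadratic form can be computed with a plain batched matmul, which avoids building the (m, m) intermediate that the tensordot route produces; a minimal sketch checking it against the einsum expression:

```python
import torch

torch.manual_seed(0)
m, n = 30, 4
B = torch.randn(m, n)
A = torch.randn(n, n)

# Row-wise quadratic form: (B @ A)[b, j] = sum_i B[b, i] * A[i, j];
# multiplying elementwise by B and summing over j gives x_b @ A @ x_b^T per row.
res_matmul = ((B @ A) * B).sum(dim=1, keepdim=True)   # shape (m, 1)

# einsum reference, as in the answer above
res_einsum = torch.einsum('bi,ij,bj -> b', B, A, B).unsqueeze(-1)

print(torch.allclose(res_matmul, res_einsum, atol=1e-6))
```

Both forms do O(m·n²) work, whereas the tensordot version computes all m² cross terms before discarding everything but the diagonal.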