I have m different vectors (call each one x), each of shape (1, n), stacked as rows into an (m, n) matrix B, and a matrix A of shape (n, n).
I want to compute xAx^T for every vector x; the output should be (m, 1).
How can I write an einsum expression for this, given B and A?
Here is a sample without einsum:
import torch

m = 30
n = 4
B = torch.randn(m, n)
A = torch.randn(n, n)

result = torch.zeros(m, 1)
for i in range(m):
    x = B[i].unsqueeze(0)                               # shape (1, n)
    result[i] = torch.matmul(x, torch.matmul(A, x.T))   # x @ A @ x^T, a scalar
Using einsum and tensordot:
import torch
A = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 2.0, 0.0, 0.0],
                  [0.0, 0.0, 3.0, 0.0],
                  [0.0, 0.0, 0.0, 4.0]])
B = torch.tensor([[1.0, 4.0, 7.0, 10.0],
                  [2.0, 5.0, 8.0, 11.0],
                  [3.0, 6.0, 9.0, 12.0]])
# Using einsum: 'bi,ij,bj -> b' keeps the batch index b and sums over i and j,
# i.e. res[b] = sum_{i,j} B[b, i] * A[i, j] * B[b, j]
res_using_einsum = torch.einsum('bi,ij,bj -> b', B, A, B).unsqueeze(dim=-1)
print(res_using_einsum)
'''
tensor([[580.],
        [730.],
        [900.]])
'''
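To verify that this matches the loop from the question, one possible sanity check (loop_result is just a throwaway name used here for the comparison):

# Recompute the result with the explicit loop and compare.
loop_result = torch.zeros(B.shape[0], 1)
for i in range(B.shape[0]):
    x = B[i].unsqueeze(0)
    loop_result[i] = torch.matmul(x, torch.matmul(A, x.T))
print(torch.allclose(res_using_einsum, loop_result))  # True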
# Using tensordot
# BA = B @ A, i.e. BA[b, j] = sum_i B[b, i] * A[i, j]
BA = torch.tensordot(B, A, dims=([1], [0]))
'''
BA:
tensor([[ 1.,  8., 21., 40.],
        [ 2., 10., 24., 44.],
        [ 3., 12., 27., 48.]])
'''
# This contracts over the feature dimension of both operands and yields the
# full (m, m) matrix B @ A @ B^T; entry (i, j) is B[i] @ A @ B[j]^T.
res_using_tensordot = torch.tensordot(BA, B, dims=([1], [1]))
'''
res_using_tensordot:
tensor([[580., 650., 720.],
        [650., 730., 810.],
        [720., 810., 900.]])
'''
# Only the diagonal entries pair a row of B with itself, so take the diagonal.
diagonal_result = torch.diagonal(res_using_tensordot, 0).unsqueeze(1)
'''
diagonal_result:
tensor([[580.],
        [730.],
        [900.]])
'''
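Note that the tensordot route materializes the whole (m, m) matrix B @ A @ B^T and then throws away everything except its diagonal, which gets wasteful for large m. A sketch of one way to avoid that intermediate, using plain elementwise multiplication and a row-wise sum (res_rowwise is just an illustrative name):

# Computes sum_{i,j} B[b, i] * A[i, j] * B[b, j] per row without the (m, m) matrix.
res_rowwise = (torch.matmul(B, A) * B).sum(dim=1, keepdim=True)
print(res_rowwise)
'''
tensor([[580.],
        [730.],
        [900.]])
'''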