I want to implement the following formula in PyTorch in batched form:
x^T A x
where x has shape [BATCH, DIM1] and A has shape [BATCH, DIM1, DIM1].
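Written out per batch element (using B for BATCH and D for DIM1), the quantity is a scalar quadratic form, so the batched result has shape [BATCH]. Associativity of matrix multiplication also means the products can be grouped either way:

```latex
(x^\top A x)_b = \sum_{i=1}^{D} \sum_{j=1}^{D} x_{b,i}\, A_{b,i,j}\, x_{b,j}
             = \big( (x_b^\top A_b)\, x_b \big) = \big( x_b^\top (A_b x_b) \big),
\qquad b = 1, \dots, B
```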
I managed to implement it for a dense A as follows:

```python
torch.bmm(torch.bmm(x.unsqueeze(1), A), x.unsqueeze(2)).squeeze()
```
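As a cross-check for the dense case, the same quadratic form can be written as a single `torch.einsum` call. This is a sketch using random data with the shapes from the question; the variable names `xTAx_bmm` and `xTAx_einsum` are just illustrative. Unlike `.squeeze()`, which collapses the result to a 0-d tensor when `batch_size == 1`, the einsum always returns shape `[BATCH]`.

```python
import torch

batch_size, dim1 = 10, 5
x = torch.rand(batch_size, dim1)
A = torch.rand(batch_size, dim1, dim1)

# bmm version from the question: [B,1,D] @ [B,D,D] @ [B,D,1] -> [B,1,1] -> squeeze
xTAx_bmm = torch.bmm(torch.bmm(x.unsqueeze(1), A), x.unsqueeze(2)).squeeze()

# Same quadratic form as one einsum: sum over i, j of x[b,i] * A[b,i,j] * x[b,j]
xTAx_einsum = torch.einsum("bi,bij,bj->b", x, A, x)

print(torch.allclose(xTAx_bmm, xTAx_einsum, atol=1e-5))
```

Note that this einsum route is only for the dense case; it is not a workaround for the sparse error below.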
However, I now need the same thing for a sparse A, and I can't get it to work. The error I get is `RuntimeError: bmm_sparse: Tensor 'mat2' must be dense`, which is raised by the `torch.bmm(x.unsqueeze(1), A)` part of the code.
To reproduce the issue, run:
```python
import torch

sparse = True  # set to False to see the working dense version
batch_size = 10
dim1 = 5
x = torch.rand(batch_size, dim1)
A = torch.rand(batch_size, dim1, dim1)
if sparse:
    A = A.to_sparse_coo()
# Raises RuntimeError when sparse=True
xTAx = torch.bmm(torch.bmm(x.unsqueeze(1), A), x.unsqueeze(2)).squeeze()
```
My PyTorch version is 1.12.1+cu116.
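The solution is simply to change the order of multiplication from (x^T A) x to x^T (A x). The two groupings give the same result, but `torch.bmm` accepts a sparse COO tensor only as its first argument (sparse @ dense -> dense). In (x^T A) x the sparse A is `mat2`, which triggers the error; in A x it is the first operand, which is supported.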
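A minimal sketch of the reordered computation, assuming the same shapes as the reproduction script and a CPU tensor built with `to_sparse_coo()`. It compares the sparse result against the original dense formulation; `A_dense`, `Ax` and `expected` are illustrative names, not from the question.

```python
import torch

torch.manual_seed(0)
batch_size, dim1 = 10, 5
x = torch.rand(batch_size, dim1)
A_dense = torch.rand(batch_size, dim1, dim1)
A = A_dense.to_sparse_coo()  # 3-D sparse COO tensor holding the same values

# Step 1: A x. The sparse tensor is the FIRST argument of bmm, which is supported:
# [B,D,D] (sparse) @ [B,D,1] (dense) -> [B,D,1] (dense)
Ax = torch.bmm(A, x.unsqueeze(2))

# Step 2: x^T (A x), both operands dense: [B,1,D] @ [B,D,1] -> [B,1,1]
# squeeze only the last two dims so the batch dim survives even when batch_size == 1
xTAx = torch.bmm(x.unsqueeze(1), Ax).squeeze(-1).squeeze(-1)

# Reference: the original dense formulation from the question
expected = torch.bmm(torch.bmm(x.unsqueeze(1), A_dense), x.unsqueeze(2)).squeeze(-1).squeeze(-1)
print(torch.allclose(xTAx, expected, atol=1e-5))
```

Because `Ax` is already dense after step 1, the second `bmm` could equally be replaced by `(x * Ax.squeeze(-1)).sum(-1)`.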