Tags: python, deep-learning, pytorch, linear-algebra

Sparse matrix multiplication in pytorch


I want to compute the following quadratic form in PyTorch, batched over the first dimension:

x^T A x

where x has shape [BATCH, DIM1] and A has shape [BATCH, DIM1, DIM1], so the result has one value per batch element (shape [BATCH]).
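For reference, here is the same quantity written out per batch element, with B = BATCH and D = DIM1 (symbols introduced here only for brevity). Because matrix multiplication is associative, both groupings in the second line give the same scalar:

```latex
\begin{aligned}
y_b &= x_b^\top A_b\, x_b = \sum_{i=1}^{D}\sum_{j=1}^{D} x_{b,i}\, A_{b,i,j}\, x_{b,j}, \qquad b = 1,\dots,B,\\
(x_b^\top A_b)\, x_b &= x_b^\top (A_b\, x_b).
\end{aligned}
```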

For a dense A, I implemented it as follows:

torch.bmm(torch.bmm(x.unsqueeze(1), A), x.unsqueeze(2)).squeeze()
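As a sanity check on the dense version, here is a minimal sketch that compares the nested torch.bmm calls with torch.einsum, which spells out the per-batch double sum directly. The sizes and the random seed are illustrative assumptions, not taken from the question:

```python
import torch

torch.manual_seed(0)
batch_size, dim1 = 10, 5
x = torch.rand(batch_size, dim1)
A = torch.rand(batch_size, dim1, dim1)

# Dense batched quadratic form: [B,1,D] @ [B,D,D] @ [B,D,1] -> [B,1,1] -> squeeze -> [B]
xTAx_bmm = torch.bmm(torch.bmm(x.unsqueeze(1), A), x.unsqueeze(2)).squeeze()

# Same quantity written per batch element: sum over i, j of x[b,i] * A[b,i,j] * x[b,j]
xTAx_einsum = torch.einsum("bi,bij,bj->b", x, A, x)

print(xTAx_bmm.shape)  # torch.Size([10])
print(torch.allclose(xTAx_bmm, xTAx_einsum))  # True
```

Note that .squeeze() removes every size-1 dimension, so with batch_size == 1 it returns a 0-d tensor rather than shape [1]; .flatten() (or indexing with [:, 0, 0]) keeps the batch dimension.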

However, I now need the same computation for a sparse A, and I cannot get it to work.

The error I get is RuntimeError: bmm_sparse: Tensor 'mat2' must be dense. It is raised by the torch.bmm(x.unsqueeze(1), A) part of the code, where the sparse A is passed as the second argument (mat2).
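The error points to an operand-order restriction: for a 3-D sparse COO tensor, torch.bmm supports the sparse tensor as the first argument with a dense second argument (returning a dense tensor), but not as mat2. A minimal sketch of the supported direction, assuming illustrative sizes and using to_sparse(), which produces a COO tensor by default:

```python
import torch

torch.manual_seed(0)
batch_size, dim1 = 4, 3
A_dense = torch.rand(batch_size, dim1, dim1)
A_sparse = A_dense.to_sparse()          # 3-D sparse COO tensor
v = torch.rand(batch_size, dim1, 1)     # dense batch of column vectors

# Supported: sparse COO as the first operand, dense as the second -> dense result
Av = torch.bmm(A_sparse, v)

print(Av.layout)  # torch.strided (a regular dense tensor)
print(Av.shape)   # torch.Size([4, 3, 1])
print(torch.allclose(Av, torch.bmm(A_dense, v)))  # True
```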

To reproduce the problem, run:

import torch
sparse = True  # set to False to see the working dense version

batch_size = 10
dim1 = 5
x = torch.rand(batch_size, dim1)
A = torch.rand(batch_size, dim1, dim1)

if sparse:
    A = A.to_sparse_coo()
# With sparse=True this line raises:
# RuntimeError: bmm_sparse: Tensor 'mat2' must be dense
xTAx = torch.bmm(torch.bmm(x.unsqueeze(1), A), x.unsqueeze(2)).squeeze()

My PyTorch version is 1.12.1+cu116.


Solution

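  • The fix is to change the order of multiplication from (x^T A) x to x^T (A x). Because matrix multiplication is associative, the result is the same, but torch.bmm supports a sparse COO tensor only as its first argument: A x is then a sparse-times-dense product that returns a dense tensor, and x^T (A x) is an ordinary dense bmm. In code: torch.bmm(x.unsqueeze(1), torch.bmm(A, x.unsqueeze(2))).squeeze()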
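Applied to the reproduction code, the reordering might look like the sketch below. The sizes and seed are illustrative, and to_sparse() stands in for to_sparse_coo() (it also yields a sparse COO tensor). The sketch checks the sparse result against the dense formulation from the question:

```python
import torch

torch.manual_seed(0)
batch_size, dim1 = 10, 5
x = torch.rand(batch_size, dim1)
A_dense = torch.rand(batch_size, dim1, dim1)
A = A_dense.to_sparse()  # sparse COO, like A.to_sparse_coo() in the question

# Step 1: A x -> sparse [B,D,D] @ dense [B,D,1] -> dense [B,D,1]
Ax = torch.bmm(A, x.unsqueeze(2))
# Step 2: x^T (A x) -> dense [B,1,D] @ dense [B,D,1] -> dense [B,1,1]
xTAx = torch.bmm(x.unsqueeze(1), Ax).flatten()  # -> [B], keeps the batch dim even if B == 1

# Reference: the dense (x^T A) x formulation from the question
expected = torch.bmm(torch.bmm(x.unsqueeze(1), A_dense), x.unsqueeze(2)).flatten()

print(xTAx.shape)                      # torch.Size([10])
print(torch.allclose(xTAx, expected))  # True
```

Once Ax is dense, the second product could equally be written as (x * Ax.squeeze(2)).sum(dim=1); using torch.bmm simply keeps the structure of the original expression.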