pytorch

How can I implement a torch.linalg.matrix_power operation in the forward function for batch sizes larger than 1?


I want to fit a model that uses torch.matrix_power in the forward method of a neural network class. However, I can only train with a batch size of 1, because torch.matrix_power accepts only a single integer power, so one forward call can only use one value of k.
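To make the limitation concrete: torch.matrix_power accepts a batch of matrices (any leading dimensions), but the exponent n is a single integer applied to every matrix in that batch. A small sketch with arbitrary shapes:

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 3)          # a single 3x3 matrix
batch = torch.randn(5, 3, 3)   # a batch of five 3x3 matrices

# The input may carry leading batch dimensions...
single = torch.matrix_power(A, 2)       # shape (3, 3)
batched = torch.matrix_power(batch, 2)  # shape (5, 3, 3)

# ...but n is one Python int, so every matrix in the batch gets the same power.
print(single.shape, batched.shape)  # torch.Size([3, 3]) torch.Size([5, 3, 3])
```

Different powers per sample therefore need either one call per value of k or a precomputed stack of powers to index into.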

k takes only a small number of integer values, e.g. between 1 and 10. I tried precomputing the matrix powers to avoid the costly power calculation in each iteration, but I got an error the second time the precomputed Ak matrix was used.
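One likely cause of that second-call error, assuming Ak was computed once from the parameter outside forward: the precomputed tensor stays tied to the autograd graph recorded at that moment, and the first backward() frees that graph's saved tensors. A minimal reproduction (the 3x3 shape and the range of powers are arbitrary):

```python
import torch

A = torch.nn.Parameter(torch.randn(3, 3))

# Precomputed ONCE, outside any forward pass: A_k stays attached to the
# autograd graph recorded right here.
A_k = torch.stack([torch.matrix_power(A, i) for i in range(4)])

A_k.sum().backward()  # first backward frees the saved tensors of that graph

second_backward_failed = False
try:
    A_k.sum().backward()  # walks the same, already-freed graph again
except RuntimeError as err:
    second_backward_failed = True
    print("second backward:", err)

# Rebuilding the powers from A on every iteration records a fresh graph,
# so repeated backward passes succeed.
for _ in range(2):
    A.grad = None
    A_k_fresh = torch.stack([torch.matrix_power(A, i) for i in range(4)])
    A_k_fresh.sum().backward()

print(second_backward_failed)
```

Passing retain_graph=True to the first backward() would silence the error, but the precomputed powers would still not reflect updates the optimizer makes to A, so rebuilding them is the safer fix.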

A small self-contained example of what I'm trying to do is below:

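import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self, dim_x):
        super().__init__()  # must run before assigning an nn.Parameter
        self.A = nn.Parameter(torch.randn(dim_x, dim_x))

    def forward(self, X, k):
        # .item() requires k to hold exactly one value,
        # which is what limits this version to a batch size of 1
        k = k.item()
        A_pow = torch.matrix_power(self.A, k)
        return X @ A_pow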
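Before vectorizing, the batch-size-1 restriction can be lifted with a plain Python loop that calls matrix_power once per sample. LoopNN is a hypothetical name used only for this sketch; it is correct but slow for large batches, so it is mainly useful as a reference to check a vectorized version against.

```python
import torch
import torch.nn as nn

class LoopNN(nn.Module):
    """Hypothetical baseline: one matrix_power call per sample in the batch."""

    def __init__(self, dim_x: int):
        super().__init__()
        self.A = nn.Parameter(torch.randn(dim_x, dim_x))

    def forward(self, X: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        # X: (batch, dim_x), k: (batch,) integer powers
        rows = [X[i] @ torch.matrix_power(self.A, int(k[i])) for i in range(X.shape[0])]
        return torch.stack(rows)  # (batch, dim_x)

model = LoopNN(dim_x=3)
X = torch.randn(4, 3)
k = torch.tensor([1, 2, 3, 2])
out = model(X, k)
out.sum().backward()
print(out.shape)  # torch.Size([4, 3])
```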

Solution

  • I think index_select does what you need. Consider this, based on your example:

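    import torch
    import torch.nn as nn

    class NN(nn.Module):
        def __init__(self, dim_x: int, max_k: int, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.max_k = max_k
            self.A = nn.Parameter(torch.randn(dim_x, dim_x))

        def forward(self, X, k):
            # Row i of A_k is A^i for i = 0..max_k, so A_k[k] == A^k.
            # Rebuilt on every call so it tracks the current A and autograd
            # records a fresh graph for each backward().
            A_k = torch.stack([torch.matrix_power(self.A, i) for i in range(self.max_k + 1)])
            # k: 1-D integer tensor of shape (batch,); result is (batch, dim_x, dim_x)
            A_pow = torch.index_select(A_k, dim=0, index=k)
            # (batch, 1, dim_x) @ (batch, dim_x, dim_x) -> (batch, 1, dim_x) -> (batch, dim_x)
            return (X.unsqueeze(1) @ A_pow).squeeze(1)


    The forward pass computes the powers of A up to max_k once per batch instead of once per sample, then uses index_select to pick the matrix for each sample according to the vector k. The powers are rebuilt from self.A on every call on purpose: a tensor computed from self.A in __init__ belongs to a single autograd graph, so a second backward() through it fails (most likely the error you ran into), and it would not follow the updates the optimizer makes to A. The range runs to max_k + 1 so that k = max_k is a valid index. Finally, X is unsqueezed to (batch, 1, dim_x) so that the product is taken per sample; with a 2-D X, @ would broadcast X against every selected matrix and return a (batch, batch, dim_x) tensor.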
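A usage sketch for the class above, checking the vectorized output against a slow per-sample computation and running two optimizer steps to confirm that backward() works on repeated calls. The class is repeated so the block runs on its own; the 1/sqrt(dim_x) scaling of A, the batch size, the random targets and the SGD settings are arbitrary demo choices (the scaling only keeps A^10 from growing too large).

```python
import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self, dim_x: int, max_k: int):
        super().__init__()
        self.max_k = max_k
        # Demo assumption: scale A so that A^10 stays numerically moderate
        self.A = nn.Parameter(torch.randn(dim_x, dim_x) / dim_x ** 0.5)

    def forward(self, X, k):
        A_k = torch.stack([torch.matrix_power(self.A, i) for i in range(self.max_k + 1)])
        A_pow = torch.index_select(A_k, dim=0, index=k)
        return (X.unsqueeze(1) @ A_pow).squeeze(1)

torch.manual_seed(0)
batch, dim_x, max_k = 8, 3, 10
model = NN(dim_x, max_k)
X = torch.randn(batch, dim_x)
k = torch.randint(1, max_k + 1, (batch,))  # one power per sample, in 1..10
target = torch.randn(batch, dim_x)

# Vectorized output vs. one matrix_power call per sample
with torch.no_grad():
    out = model(X, k)
    ref = torch.stack([X[i] @ torch.matrix_power(model.A, int(k[i])) for i in range(batch)])

# Two training steps: each backward() runs on a freshly built graph
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
losses = []
for _ in range(2):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(X, k), target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(out.shape, losses)
```

torch.einsum("bi,bij->bj", X, A_pow) is an equivalent way to write the final per-sample product. If max_k grows, the stack could instead be built incrementally (each power as the previous one times A), which needs only max_k matrix products in total; matrix_power is kept here to stay close to the original answer.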