python, pytorch

Matrix multiplication in log space for a sum of exponentials


I have been working on an implementation for the following problem: given two tensors A and B, both of shape (bs, n, m, m), I want to compute the following expression efficiently.

    out = torch.log(torch.exp(A).sum(dim=1) @ torch.exp(B).sum(dim=1))

The problem is that the exponentials are sometimes too large, and I get overflow.
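To make the overflow concrete, here is a minimal sketch comparing the naive log-of-sum-of-exp with torch.logsumexp. The input values are arbitrary and were chosen only so that float32 overflows.

```python
import math
import torch

x = torch.tensor([100.0, 100.0])  # float32 by default

# exp(100) is about 2.7e43, above the float32 maximum (about 3.4e38), so it becomes inf
naive = torch.log(torch.exp(x).sum())

# logsumexp never materializes exp(100), so the result stays finite
stable = torch.logsumexp(x, dim=0)

print(naive)   # tensor(inf)
print(stable)  # approximately 100.6931, i.e. 100 + log(2)
```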

I know that code already exists for the expression:

    out = torch.log(torch.exp(A) @ torch.exp(B))

I found it here, and it works in the regular case where there is no sum over dim=1. I tried to adapt that code to the first expression, but without success. I would appreciate any help.
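For readers without the linked code at hand, routines for the two-tensor expression are often written by subtracting a maximum before exponentiating. The sketch below is an assumption about what such a routine can look like, not necessarily the code referenced above; the name log_matmul_exp is hypothetical, and it assumes all entries are finite.

```python
import torch

def log_matmul_exp(log_X, log_Y):
    """Sketch of log(exp(log_X) @ exp(log_Y)) using a max shift.

    log_X: (..., p, q), log_Y: (..., q, r); entries are assumed finite.
    """
    # Largest entry in each row of X and each column of Y
    max_X = log_X.amax(dim=-1, keepdim=True)  # (..., p, 1)
    max_Y = log_Y.amax(dim=-2, keepdim=True)  # (..., 1, r)
    # After the shift every exponent is <= 0, so exp cannot overflow
    prod = torch.exp(log_X - max_X) @ torch.exp(log_Y - max_Y)
    # Add the shifts back in log space
    return torch.log(prod) + max_X + max_Y

torch.manual_seed(0)
X = torch.randn(2, 3, 4, dtype=torch.float64)
Y = torch.randn(2, 4, 5, dtype=torch.float64)

reference = torch.log(torch.exp(X) @ torch.exp(Y))  # safe here: entries are small
shifted = log_matmul_exp(X, Y)
large = log_matmul_exp(X + 1000.0, Y + 1000.0)      # naive exp would overflow
```

One caveat of this variant: the shifted product can still underflow to zero when the large entries of a row of X and of a column of Y sit at different positions, in which case the result is -inf.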


Solution

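  • I (with help from DeepSeek) found a solution to this problem; I hope it helps someone else. The sum over dim=1 is taken in log space with torch.logsumexp, and the matrix product then becomes a second logsumexp over the shared inner index.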

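    import torch

    def stable_implementation(A, B):
        # Sum over dim=1 in log space: log(sum_n exp(A[:, n]))
        log_S_A = torch.logsumexp(A, dim=1)  # Shape: (bs, m, m)
        log_S_B = torch.logsumexp(B, dim=1)  # Shape: (bs, m, m)
        # combined[b, i, k, j] = log_S_A[b, i, k] + log_S_B[b, k, j]
        combined = log_S_A.unsqueeze(3) + log_S_B.unsqueeze(1)  # Shape: (bs, m, m, m)
        # Reduce over the shared index k to get the log of the matrix product
        out = torch.logsumexp(combined, dim=2)  # Shape: (bs, m, m)
        return out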
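A quick sanity check of the function above, as a sketch: it compares the result with the naive expression on small float64 inputs, then shifts the inputs by 1000 so that the naive version overflows. The sizes bs, n, m and the shift are arbitrary choices for illustration.

```python
import torch

def stable_implementation(A, B):
    # Same function as in the answer, repeated so this snippet runs on its own
    log_S_A = torch.logsumexp(A, dim=1)
    log_S_B = torch.logsumexp(B, dim=1)
    combined = log_S_A.unsqueeze(3) + log_S_B.unsqueeze(1)
    return torch.logsumexp(combined, dim=2)

def naive_implementation(A, B):
    # The original expression from the question; overflows for large inputs
    return torch.log(torch.exp(A).sum(dim=1) @ torch.exp(B).sum(dim=1))

torch.manual_seed(0)
bs, n, m = 2, 3, 4  # arbitrary small sizes
A = torch.randn(bs, n, m, m, dtype=torch.float64)
B = torch.randn(bs, n, m, m, dtype=torch.float64)

# Small inputs: both versions should agree
naive = naive_implementation(A, B)
stable = stable_implementation(A, B)
print(torch.allclose(naive, stable))

# Large inputs: exp(1000) overflows even in float64, so the naive result is inf
naive_big = naive_implementation(A + 1000.0, B + 1000.0)
stable_big = stable_implementation(A + 1000.0, B + 1000.0)
print(torch.isinf(naive_big).all(), torch.isfinite(stable_big).all())
```

Adding a constant c to A and d to B shifts the exact result by c + d, so stable_big should equal stable + 2000 up to rounding. Note that the intermediate tensor combined has shape (bs, m, m, m), so memory use grows with the cube of m.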