I have been working on an implementation of the following problem: given two tensors A and B, both of shape (bs, n, m, m), I want to compute the following expression efficiently:
out = torch.log(torch.exp(A).sum(dim=1)@torch.exp(B).sum(dim=1))
The problem is that the exponentials are sometimes too large and I get overflow.
I know there is a numerically stable implementation of the expression:
out = torch.log(torch.exp(A)@torch.exp(B))
which I found here, and it works in the regular case where there is no sum over dim=1. I tried to adapt that code to compute the expression above, but without success; I would appreciate any help.
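For context, one common formulation of that stable log-matmul-exp trick (the function name here is my own; the code linked in the question may differ in detail) replaces the matrix product with a logsumexp over the contraction dimension:

```python
import torch

def log_matmul_exp(A, B):
    # Computes log(exp(A) @ exp(B)) without materializing exp(A) or exp(B):
    # out[..., i, j] = logsumexp_k(A[..., i, k] + B[..., k, j])
    # A: (..., m, p), B: (..., p, q) -> out: (..., m, q)
    return torch.logsumexp(A.unsqueeze(-1) + B.unsqueeze(-3), dim=-2)
```

The broadcasted sum has shape (..., m, p, q), so this trades the overflow risk for extra memory proportional to the contraction dimension.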
Well, I (with some help from DeepSeek) found a solution to this problem; I hope it will be helpful to someone else. The key observation is that log(exp(A).sum(dim=1)) is exactly torch.logsumexp(A, dim=1), and the matrix product in log-space then becomes another logsumexp over the contraction dimension.
import torch

def stable_implementation(A, B):
    # Stable log of the summed-exp matrices: log(exp(A).sum(dim=1))
    log_S_A = torch.logsumexp(A, dim=1)  # shape: (bs, m, m)
    log_S_B = torch.logsumexp(B, dim=1)  # shape: (bs, m, m)
    # Pairwise sums log_S_A[b, i, k] + log_S_B[b, k, j]
    combined = log_S_A.unsqueeze(3) + log_S_B.unsqueeze(1)  # shape: (bs, m, m, m)
    # logsumexp over k gives log(S_A @ S_B)
    out = torch.logsumexp(combined, dim=2)  # shape: (bs, m, m)
    return out
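A quick sanity check of the approach (the same logsumexp steps are inlined here so the snippet is self-contained; shapes and tolerances are my own choices):

```python
import torch

torch.manual_seed(0)
bs, n, m = 2, 3, 4
A = torch.randn(bs, n, m, m)
B = torch.randn(bs, n, m, m)

def stable(A, B):
    # Same steps as the answer's function, inlined.
    return torch.logsumexp(
        torch.logsumexp(A, dim=1).unsqueeze(3)
        + torch.logsumexp(B, dim=1).unsqueeze(1),
        dim=2,
    )

# Small values: matches the naive formula.
naive = torch.log(torch.exp(A).sum(dim=1) @ torch.exp(B).sum(dim=1))
print(torch.allclose(stable(A, B), naive, atol=1e-5))  # True

# Large values: the naive formula overflows to inf, the stable one stays finite.
print(torch.isfinite(stable(A + 500.0, B + 500.0)).all().item())  # True
```

Note that `combined` materializes a (bs, m, m, m) tensor, so for large m you may want to chunk over one of the output dimensions.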