Tags: python, matrix, pytorch

Efficient implementation of matrix operation in pytorch


I would like to write this more efficiently, ideally using vectorisation. Is there a way to do it?

import torch

m = 100
n = 100
b = torch.rand(m)
a = torch.rand(m)
sumation = 0
A = torch.rand(n,n)
# Accumulate a[i] / (A - b[i] * I) over i; "/" is elementwise division
for i in range(m):
    sumation = sumation + a[i]/(A - b[i]*torch.eye(n))
print(sumation)
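For reference, this is what the loop computes. In PyTorch, `/` between tensors is elementwise, so each term divides a_i by every entry of A - b_i I (it is not a matrix inverse). Writing ⊘ for elementwise division and δ_{jk} for the Kronecker delta:

```latex
S = \sum_{i=1}^{m} a_i \oslash \left(A - b_i I_n\right),
\qquad
S_{jk} = \sum_{i=1}^{m} \frac{a_i}{A_{jk} - b_i\,\delta_{jk}}
```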
    

I have also tried this:

# Stack the m terms into an (m, n, n) tensor and sum over the first dimension
torch.sum(torch.stack([a[i]/(A - b[i]*torch.eye(n)) for i in range(m)], dim=0), dim=0)

# But comparing this with the loop result gives False for some elements of the output matrix
torch.sum(torch.stack([a[i]/(A - b[i]*torch.eye(n)) for i in range(m)], dim=0), dim = 0) == sumation
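A likely explanation for the mismatch is rounding: floating-point addition is not associative, so summing the same terms in a different order (as a reduction over a stacked tensor may do, compared with the left-to-right loop) can change the last bits of the result. A minimal illustration with three scalars:

```python
import torch

# Three float64 values; regrouping their sum changes the rounded result
x = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)
left = (x[0] + x[1]) + x[2]   # (0.1 + 0.2) + 0.3
right = x[0] + (x[1] + x[2])  # 0.1 + (0.2 + 0.3)

print(left.item(), right.item())    # 0.6000000000000001 0.6
print(bool(left == right))          # False: exact comparison fails
print(torch.allclose(left, right))  # True: equal within tolerance
```

This is why exact `==` is the wrong test for floating-point results computed in different ways; a tolerance-based comparison such as `torch.allclose` is the usual choice.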

Solution

  • You can do this with broadcasting, computing all m terms at once instead of looping:

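    import torch

    m = 100
    n = 100
    b = torch.rand(m)
    a = torch.rand(m)
    A = torch.rand(n, n)  # to compare with the loop, reuse the same a, b and A instead

    # B[i] = b[i] * I for every i, shape (m, n, n)
    B = torch.eye(n).unsqueeze(0) * b.unsqueeze(1).unsqueeze(2)
    # Broadcast A against B: A_minus_B[i] = A - b[i] * I, shape (m, n, n)
    A_minus_B = A.unsqueeze(0) - B
    # Elementwise a[i] / (A - b[i] * I), summed over i -> shape (n, n)
    summation = torch.sum(a.unsqueeze(1).unsqueeze(2) / A_minus_B, dim=0)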
    

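    Note that because of floating-point rounding, the result won't be bit-for-bit identical: the vectorised version may add the m terms in a different order than the loop, and floating-point addition is not associative. So (summation_old == summation_new).all() will typically be False, while torch.allclose(summation_old, summation_new), which compares up to a small tolerance, will return True.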
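A sketch that checks the broadcast version against the original loop on the same inputs, plus a lower-memory variant. The variant is an algebraic rewrite, not part of the answer above: b only shifts the diagonal, so for j != k every term is a[i] / A[j, k] and the sum collapses to a.sum() / A[j, k], while the diagonal needs only an (m, n) tensor instead of (m, n, n). The function names `loop_sum`, `broadcast_sum` and `diagonal_sum` are illustrative, and float64 is assumed to keep rounding differences small.

```python
import torch


def loop_sum(a, b, A):
    """Reference: the loop from the question."""
    eye = torch.eye(A.shape[0], dtype=A.dtype)
    total = torch.zeros_like(A)
    for i in range(a.shape[0]):
        total = total + a[i] / (A - b[i] * eye)
    return total


def broadcast_sum(a, b, A):
    """The answer's approach, written with None-indexing instead of unsqueeze.
    Materialises an (m, n, n) tensor."""
    eye = torch.eye(A.shape[0], dtype=A.dtype)
    shifted = A[None, :, :] - b[:, None, None] * eye  # (m, n, n)
    return (a[:, None, None] / shifted).sum(dim=0)


def diagonal_sum(a, b, A):
    """Hypothetical lower-memory variant (not from the answer).
    Off-diagonal entries equal a.sum() / A; only the diagonal depends on b."""
    result = a.sum() / A                                   # correct off the diagonal
    diag = torch.diagonal(A)                               # (n,)
    diag_terms = a[:, None] / (diag[None, :] - b[:, None])  # (m, n)
    idx = torch.arange(A.shape[0])
    result[idx, idx] = diag_terms.sum(dim=0)               # overwrite the diagonal
    return result


torch.manual_seed(0)
m, n = 100, 100
a = torch.rand(m, dtype=torch.float64)
b = torch.rand(m, dtype=torch.float64)
A = torch.rand(n, n, dtype=torch.float64)

ref = loop_sum(a, b, A)
print(torch.allclose(ref, broadcast_sum(a, b, A)))
print(torch.allclose(ref, diagonal_sum(a, b, A)))
```

The broadcast version is the simplest drop-in replacement; the diagonal-only variant is worth considering only when m * n * n becomes too large to hold in memory.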