Tags: python, pytorch, matmul

Why is the result of matrix multiplication in torch so different when I roll the matrix?


I know that floating-point arithmetic has limited accuracy, but the gap here is larger than I expected, and it also depends on the roll step.

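import torch

x = torch.rand((1, 5))
y = torch.rand((5, 1))
# .item() extracts the single element of the 1x1 result as a Python float.
# The values below are from one run; x and y are random.
print("%.10f" % torch.matmul(x, y).item())
# 1.2710412741
print("%.10f" % torch.matmul(torch.roll(x, 1, 1), torch.roll(y, 1, 0)).item())
# 1.2710412741
print("%.10f" % torch.matmul(torch.roll(x, 2, 1), torch.roll(y, 2, 0)).item())
# 1.2710413933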

What causes the difference above? How can I get a more consistent result?


Solution

  • Floating point additions are not associative, hence you're not guaranteed to get the same result for different orders of the summands.

    If you want to reduce this effect, you can use a compensated summation scheme such as the Kahan algorithm.

    But this all comes with a big caveat: if you really have to rely on order-independent results, you should consider a different way of representing your numbers (see the first link). Floating-point numbers work well for numerical computations, but using them means dealing with several different sources of error. Again, I recommend reading the first linked page thoroughly and familiarizing yourself with the inner workings of floating-point numbers, e.g. https://en.wikipedia.org/wiki/Floating-point_arithmetic
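To make the two points above concrete, here is a minimal sketch in pure Python (no torch), assuming IEEE 754 double precision. The helpers `naive_dot`, `kahan_dot`, and `roll` are hypothetical names written for this illustration, not part of PyTorch, and this is not how `torch.matmul` is implemented internally. It only shows that a plain left-to-right sum can change when the summands are rolled, while a Kahan-compensated sum stays close to the correctly rounded result from `math.fsum`.

```python
import math


def roll(seq, shift):
    """Cyclically shift a list to the right, like torch.roll along one axis."""
    shift %= len(seq)
    return seq[-shift:] + seq[:-shift] if shift else list(seq)


def naive_dot(xs, ys):
    """Dot product accumulated left to right in plain floating point."""
    total = 0.0
    for a, b in zip(xs, ys):
        total += a * b
    return total


def kahan_dot(xs, ys):
    """Dot product accumulated with Kahan (compensated) summation."""
    total = 0.0
    comp = 0.0  # estimate of the low-order bits lost so far
    for a, b in zip(xs, ys):
        term = a * b - comp               # re-inject the previously lost part
        new_total = total + term          # low-order bits of term may be lost here
        comp = (new_total - total) - term  # recover what was just lost
        total = new_total
    return total


# Grouping alone changes the result of a floating-point sum.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))

# Illustrative data: one large product followed by several tiny ones.
xs = [1.0, 1e-16, 1e-16, 1e-16, 1e-16]
ys = [1.0, 1.0, 1.0, 1.0, 1.0]

for shift in range(len(xs)):
    rx, ry = roll(xs, shift), roll(ys, shift)
    reference = math.fsum(a * b for a, b in zip(rx, ry))
    print(shift, repr(naive_dot(rx, ry)), repr(kahan_dot(rx, ry)), repr(reference))
```

With these inputs, the naive result for shift 0 differs from the one for shift 4, because the tiny terms are absorbed one by one when the large term comes first. Compensated summation costs a few extra operations per term and narrows the gap, but it does not make floating-point results exact.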