Tags: python, numpy, pytorch, batch-processing

Why does breaking down a batch in PyTorch lead to different results?


I was trying something with batch processing in PyTorch. In the code below, you may think of x as a batch of size 2 (each sample is a 10-dimensional vector), and of x[0] as the first sample in that batch.

import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc1 = nn.Linear(10,10)

    def forward(self, x):
        x = self.fc1(x)
        return x

f = net()

x = torch.randn(2,10)
print(f(x[0])==f(x)[0])

Ideally, f(x[0])==f(x)[0] should give a tensor with all entries True. But the output on my machine is

tensor([False, False,  True,  True, False, False, False, False,  True, False])

Why does this happen? Is it a computational error? Or is it related to how batch processing is implemented in PyTorch?


Update: I simplified the code a bit. The question remains the same.

My reasoning: I believe f(x)[0]==f(x[0]) should have all its entries True because the laws of matrix multiplication say so. Think of x as a 2x10 matrix, and of the linear transformation f() as represented by a matrix B (ignoring the bias for a moment), so that f(x) = xB in this notation. Matrix multiplication says that xB can equally be computed by multiplying each row of x by B on the right separately and then stacking the two resulting rows. Translated back to code, this means f(x[0])==f(x)[0] and f(x[1])==f(x)[1] should hold entrywise.

Even if we account for the bias, the same bias vector is added to every row, so the equality should still hold.

Also note that no training is done here. Hence how the weights are initialized shouldn't matter.
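
For what it's worth, the row-splitting argument can be checked directly in integer arithmetic, where no rounding can interfere (a minimal sketch, not part of my original experiment):

import torch

# Integer tensors: every product and sum is exact, so the matrix law
# holds bitwise, unlike in float32.
A = torch.randint(-5, 5, (2, 10))
B = torch.randint(-5, 5, (10, 10))

print(torch.equal((A @ B)[0], A[0] @ B))  # True, exactly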


Solution

  • TL;DR

    Under the hood, nn.Linear uses a function named addmm that has some optimizations and probably multiplies the vectors in a slightly different way.
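
    To make that concrete, here is a minimal sketch with raw ops (assumed shapes, not the original code): the fused addmm call that nn.Linear uses for 2-D input versus the unfused matmul-plus-bias path it uses for a single 1-D sample.

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 10)
    W = torch.randn(10, 10)
    b = torch.randn(10)

    batched = torch.addmm(b, x, W.t())[0]   # fused path, taken for 2-D input
    single = x[0].matmul(W.t()) + b         # unfused path, taken for 1-D input

    print(torch.equal(batched, single))     # may be False: bits can differ
    print(torch.allclose(batched, single))  # True: equal within float32 tolerance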


    I just understood what the real issue was, and I have edited the answer accordingly.

    After trying to reproduce and debug it on my machine, I found the following:

    >>> f(x)[0].detach().numpy()
    array([-0.5386441 ,  0.4983463 ,  0.07970242,  0.53507525,  0.71045876,
            0.7791027 ,  0.29027492, -0.07919329, -0.12045971, -0.9111403 ],
          dtype=float32)
    >>> f(x[0]).detach().numpy()
    array([-0.5386441 ,  0.49834624,  0.07970244,  0.53507525,  0.71045876,
            0.7791027 ,  0.29027495, -0.07919335, -0.12045971, -0.9111402 ],
          dtype=float32)
    >>> f(x[0]).detach().numpy() == f(x)[0].detach().numpy()
    array([ True, False, False,  True,  True,  True, False, False,  True,
           False])
    

    If you take a close look, you will find that at every index that compares False, the two values differ slightly in the last printed digits (around the seventh decimal place), i.e. at float32 rounding level.
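
    To quantify this (a quick sketch reusing f and x from the question): the largest elementwise gap sits at float32 rounding level, so torch.allclose passes even though elementwise == does not.

    a = f(x)[0].detach()
    b = f(x[0]).detach()

    print((a - b).abs().max())   # tiny, on the order of 1e-7
    print(torch.allclose(a, b))  # True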

    After some more debugging, I saw that the linear function uses addmm:

    def linear(input, weight, bias=None):
        if input.dim() == 2 and bias is not None:
            # fused op is marginally faster
            ret = torch.addmm(bias, input, weight.t())
        else:
            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            ret = output
        return ret
    

    Here addmm implements beta*mat + alpha*(mat1 @ mat2) and is supposedly faster (see here for example). Since x is 2-D, f(x) takes the fused addmm path, while the 1-D sample x[0] falls through to the separate matmul-plus-bias path; the two paths can round differently, which is exactly the tiny mismatch observed above.
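
    You can check that formula directly (a sketch with arbitrary shapes, not from the original debugging session):

    import torch

    bias = torch.randn(4)   # broadcasts against the (3, 4) product
    mat1 = torch.randn(3, 5)
    mat2 = torch.randn(5, 4)

    fused = torch.addmm(bias, mat1, mat2)  # beta=1, alpha=1 by default
    manual = bias + mat1 @ mat2

    print(torch.allclose(fused, manual))   # True, up to rounding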

    Credit to Szymon Maszke