I was trying something with batch processing in PyTorch. In my code below, you may think of x as a batch of batch size 2 (each sample is a 10-d vector). I use x[0] to denote the first sample in x.
import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc1 = nn.Linear(10,10)

    def forward(self, x):
        x = self.fc1(x)
        return x

f = net()
x = torch.randn(2,10)
print(f(x[0])==f(x)[0])
Ideally, f(x[0])==f(x)[0] should give a tensor whose entries are all True. But the output on my computer is
tensor([False, False, True, True, False, False, False, False, True, False])
Why does this happen? Is it a computational error? Or is it related to how batch processing is implemented in PyTorch?
Update: I simplified the code a bit. The question remains the same.
My reasoning:
I believe f(x)[0]==f(x[0]) should have all its entries True, because the rules of matrix multiplication say so. Let us think of x as a 2x10 matrix, and think of the linear transformation f() as represented by a matrix B (ignoring the bias for a moment). Then f(x) = xB in our notation. By the definition of matrix multiplication, computing xB is the same as multiplying each of the two rows by B on the right separately and then stacking the two resulting rows back together. Translated back to the code, this means f(x[0])==f(x)[0] and f(x[1])==f(x)[1].
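The row-by-row argument does hold when the arithmetic is exact. A minimal sketch to check it, assuming small integer values stored as float32 so that every product and partial sum is exactly representable and no rounding can occur; `B` and `b` are illustrative stand-ins for the weight matrix and the bias, not names from the code above:

```python
import torch

torch.manual_seed(0)

# Small integers stored as float32: every product and partial sum is an
# integer with magnitude at most 10 * 5 * 5 + 5 = 255, so nothing is rounded.
x = torch.randint(-5, 6, (2, 10)).float()   # batch of 2 samples
B = torch.randint(-5, 6, (10, 10)).float()  # plays the role of the matrix B
b = torch.randint(-5, 6, (10,)).float()     # one bias shared by every row

batched = x @ B + b   # transform the whole batch at once
row0 = x[0] @ B + b   # transform only the first sample

# With exact arithmetic the summation order cannot matter.
print(torch.equal(batched[0], row0))
```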
Even if we consider the bias, every row should have the same bias and the equality should still hold.
Also note that no training is done here. Hence how the weights are initialized shouldn't matter.
TL;DR
Under the hood, for a batched (2-D) input, nn.Linear uses a function named addmm that has some optimizations and probably multiplies the vectors in a slightly different way than the path taken for a single sample.
I just understood what the real issue was, and I edited the answer.
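Why would a slightly different way of multiplying change the result at all? Floating-point addition is not associative, so adding the same numbers in a different order can round differently. A minimal illustration, with values deliberately chosen to make the effect large:

```python
import torch

# float32 has a 24-bit significand; near 1e8 neighbouring representable
# values are 8 apart, so adding 1.0 to 1e8 is rounded away.
a = torch.tensor(1e8, dtype=torch.float32)
one = torch.tensor(1.0, dtype=torch.float32)

left = (a + one) - a   # 1e8 + 1 rounds back to 1e8, so this is 0.0
right = (a - a) + one  # the large terms cancel first, so this is 1.0

print(left.item(), right.item())  # 0.0 1.0

# The same effect, much smaller, with ordinary Python (float64) numbers:
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```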
After reproducing and debugging it on my machine, I found that:
f(x)[0].detach().numpy()
>>>array([-0.5386441 , 0.4983463 , 0.07970242, 0.53507525, 0.71045876,
0.7791027 , 0.29027492, -0.07919329, -0.12045971, -0.9111403 ],
dtype=float32)
f(x[0]).detach().numpy()
>>>array([-0.5386441 , 0.49834624, 0.07970244, 0.53507525, 0.71045876,
0.7791027 , 0.29027495, -0.07919335, -0.12045971, -0.9111402 ],
dtype=float32)
f(x[0]).detach().numpy() == f(x)[0].detach().numpy()
>>>array([ True, False, False, True, True, True, False, False, True,
False])
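If you look closely, you will see that at every index where the comparison is False, the two values differ only slightly, by roughly 1e-8 to 1e-7. That is in the last digit or two that float32 can represent (float32 carries only about 7 significant decimal digits).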
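The size of the mismatch can be measured directly. As a sketch (the seed, the use of a bare nn.Linear instead of the net wrapper, and the tolerances are my choices, not from the question), the largest absolute difference is compared with float32's machine epsilon, and torch.allclose is used instead of ==, which is the usual way to compare floating-point results:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(10, 10)
x = torch.randn(2, 10)

with torch.no_grad():
    batched = fc1(x)[0]  # first row of the batched result
    single = fc1(x[0])   # first sample processed on its own

max_diff = (batched - single).abs().max().item()
print("max abs difference:", max_diff)
print("float32 epsilon:", torch.finfo(torch.float32).eps)

# Compare with a tolerance instead of exact equality.
print(torch.allclose(batched, single, rtol=1e-5, atol=1e-6))
```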
After some more debugging, I saw that the linear function uses addmm:
import torch

def linear(input, weight, bias=None):
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        # any other case (e.g. a single 1-D sample): matmul, then add the bias
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret
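The two branches of the excerpt above can be called by hand and compared with the corresponding nn.Linear calls. This sketch assumes the dispatch shown in the excerpt (2-D input with a bias goes to addmm, a 1-D input goes to matmul followed by a bias add); newer PyTorch releases implement linear in C++, so the exact routing may differ between versions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(10, 10)
x = torch.randn(2, 10)
W, b = fc1.weight, fc1.bias

with torch.no_grad():
    # Branch taken by the 2-D batch: one fused call computing b + x @ W^T.
    fused = torch.addmm(b, x, W.t())
    # Branch taken by the 1-D sample: matrix-vector product, then bias add.
    unfused = x[0].matmul(W.t()) + b

    # Expected True when the routing matches the excerpt (version dependent).
    print(torch.equal(fused, fc1(x)))
    print(torch.equal(unfused, fc1(x[0])))
    # The two branches agree up to float32 rounding, not necessarily exactly.
    print(torch.allclose(fused[0], unfused, atol=1e-6))
```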
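addmm implements beta*mat + alpha*(mat1 @ mat2) (here beta = alpha = 1, so it computes bias + input @ weight.t() in a single call) and is supposedly faster (see here for example). Note that the batch x is 2-D and therefore takes the addmm branch, while the single sample x[0] is 1-D and takes the matmul branch followed by a separate bias addition. The two branches can accumulate the same products in a different order, and since floating-point arithmetic rounds at every step, the last digits of the results can differ.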
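The addmm formula can be checked against a hand-written version. A small sketch with arbitrary beta and alpha values chosen for illustration; torch.addmm accepts them as keyword arguments, and the first argument (mat) is broadcast, which is how a 1-D bias gets added to every row in linear:

```python
import torch

torch.manual_seed(0)
mat = torch.randn(2, 3)   # the added term ("input" in the PyTorch docs)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)
beta, alpha = 0.5, 2.0

fused = torch.addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
manual = beta * mat + alpha * (mat1 @ mat2)

# Same values up to rounding; the fused call need not match bit for bit.
print(torch.allclose(fused, manual, atol=1e-6))
```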