
PyTorch load_state_dict() does not load precise value


For simplicity, say I want to set all parameters of a PyTorch model to the constant 72114982 with this code:

import torch

model = Net()
params = model.state_dict()

# fill every tensor in the state dict with the constant
for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.long)

model.load_state_dict(params)
print(model.state_dict().values())

The print statement then shows that all values actually get set to 72114984, which is 2 off from the value I initially intended.

For simplicity, define Net as follows:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # layer shapes are arbitrary here; no forward() is needed
        # to reproduce the issue
        self.conv1 = nn.Conv2d(2, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(2, 2)

Solution

  • This is an issue of data types.

    Model parameters are float32 tensors by default, so load_state_dict() casts your long values to float32. A float32 carries a 24-bit significand, so not every integer above 2**24 is representable; in the range [2**26, 2**27), which contains 72114982, only multiples of 8 are. 72114982 therefore rounds to the nearest representable float32, 72114984.
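
    A quick check of the default dtype (using the Net defined above):

    next(Net().parameters()).dtype
    > torch.float32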

    You can verify the rounding with the following:

    x = torch.tensor(72114982, dtype=torch.long)
    y = x.float() # y will actually be `72114984.0`
    
    # this returns `True` because x is cast to a float before evaluating
    x == y
    > tensor(True)
    
    # for the same reason, this returns 0.
    y - x
    > tensor(0.)
    
    # this returns `False` because the tensors have different values and we don't cast to float
    x == y.long()
    > tensor(False)
    
    # as longs, the difference correctly evaluates to 2
    y.long() - x
    > tensor(2)
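
    The float32 spacing in this range can also be checked directly with torch.nextafter (a sketch; torch.nextafter is available in PyTorch 1.7 and later):

    # adjacent float32 values near y are 8 apart,
    # so 72114982 itself cannot be stored exactly
    torch.nextafter(y, torch.tensor(float("inf"))) - y
    > tensor(8.)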
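
    If you need the exact integer preserved, one possible workaround (a sketch, assuming double precision is acceptable for your model) is to convert the model to float64, which represents every integer up to 2**53 exactly:

    model = Net().double()  # parameters are now float64
    params = model.state_dict()

    for k, v in params.items():
        params[k] = torch.full(v.shape, 72114982, dtype=torch.double)

    model.load_state_dict(params)
    # every value is now exactly 72114982.0
    print(model.state_dict().values())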