Tags: python, machine-learning, pytorch, runtime-error

Result type cast error when doing calculations with PyTorch model parameters


When I ran the code below:

import torchvision

model = torchvision.models.densenet201(num_classes=10)
params = model.state_dict()
for var in params:
    params[var] *= 0.1

a RuntimeError was reported:

RuntimeError: result type Float can't be cast to the desired output type Long

But when I changed params[var] *= 0.1 to params[var] = params[var] * 0.1, the error disappeared.

Why would this happen?

I thought params[var] *= 0.1 had the same effect as params[var] = params[var] * 0.1.


Solution

  • First, look at the first long-type parameter in densenet201: features.norm0.num_batches_tracked. When the model contains a BatchNormalization layer, this buffer counts the mini-batches seen during training, which are used to compute the running mean and variance. Because it behaves like a counter, it is stored as a long (integer) tensor and cannot be a float.
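
    You can check this directly with a minimal sketch (on a recent torchvision build this prints torch.int64):

      import torchvision

      model = torchvision.models.densenet201(num_classes=10)
      params = model.state_dict()
      # The batch counter is a 64-bit integer (long) tensor, not a float
      print(params['features.norm0.num_batches_tracked'].dtype)  # torch.int64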

    Second, in PyTorch there are two types of operations (a minimal sketch follows this list):

    • Non-in-place operations: the result of the calculation is assigned to a new copy of the variable, e.g. x = x + 1 or x = x / 2. The memory location of x before the assignment is not equal to the memory location after it, because x has been rebound to a new tensor.
    • In-place operations: the calculation is applied directly to the original tensor, without making any copy, e.g. x += 1 or x /= 2.
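
    Here is a minimal sketch on a plain tensor; id() shows whether a new tensor object was created:

      import torch

      x = torch.zeros(3, dtype=torch.long)

      before = id(x)
      x = x + 1               # non-in-place: creates a new tensor and rebinds x
      print(id(x) == before)  # False

      before = id(x)
      x += 1                  # in-place: modifies the original tensor's storage
      print(id(x) == before)  # True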

    Let's move to your example to understand what happened:

    1. Non-in-place operation:

      model = torchvision.models.densenet201(num_classes=10)
      params = model.state_dict()
      name = 'features.norm0.num_batches_tracked'
      
      print(id(params[name]))  # 140247785908560
      params[name] = params[name] + 0.1
      print(id(params[name]))  # 140247785908368  
      print(params[name].type()) # changed to torch.FloatTensor
      
    2. In-place operation:

      print(id(params[name]))  # 140247785908560
      params[name] += 1
      print(id(params[name]))  # 140247785908560 
      print(params[name].type()) # still torch.LongTensor
      
      params[name] += 0.1     # an in-place op cannot change the original long tensor to float, so this raises the RuntimeError
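
    Given this, one way to keep the in-place multiply from your original loop while avoiding the error is to skip the integer tensors (a sketch, assuming you only intend to scale the floating-point parameters):

      import torchvision

      model = torchvision.models.densenet201(num_classes=10)
      params = model.state_dict()

      for var in params:
          # Integer buffers such as num_batches_tracked are skipped, so the
          # in-place multiply only ever touches float tensors
          if params[var].is_floating_point():
              params[var] *= 0.1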
      

    Finally, some remarks:

    • In-place operations save some memory, but can be problematic when computing derivatives because of an immediate loss of history; hence, their use is discouraged (source). A short demonstration follows these remarks.
    • You should be cautious when you decide to use in-place operations since they overwrite the original content.
    • If you use pandas, this is a bit similar to its inplace=True argument :).
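
    To see the derivative problem concretely, here is a sketch where an in-place modification invalidates a value the backward pass needs (exp saves its output to compute its gradient):

      import torch

      x = torch.ones(3, requires_grad=True)
      y = x.exp()         # autograd saves y, since d(exp(x))/dx = exp(x) = y
      y.add_(1)           # in-place change bumps y's version counter
      y.sum().backward()  # RuntimeError: one of the variables needed for gradient
                          # computation has been modified by an inplace operation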

    This is a good resource to read more about in-place operations (source), and this discussion (source) is also worth reading.