
RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead


I wrote the code below using PyTorch and ran into a runtime error:

import torch
tns = torch.tensor([1,0,1])
tns.mean()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-666-194e5ab56931> in <module>
----> 1 tns.mean()

RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead.

However, if I change the tensor to float, the error goes away:

tns = torch.tensor([1.,0,1])
tns.mean()
tensor(0.6667)
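The only difference between the two snippets is the dtype that torch.tensor infers from the Python literals. A quick check, assuming the default floating-point dtype has not been changed with torch.set_default_dtype:

```python
import torch

ints = torch.tensor([1, 0, 1])     # only Python ints -> integer tensor
floats = torch.tensor([1., 0, 1])  # a single Python float makes the whole tensor floating point

print(ints.dtype)    # torch.int64
print(floats.dtype)  # torch.float32 (the default floating-point dtype)
```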

My question is why the error happens. The data type of the first tensor is int64, not Long, so why does PyTorch treat it as Long?


Solution

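  • This is because torch.int64 and torch.long both refer to the same data type: 64-bit signed integers. "Long" is simply the name PyTorch uses internally, and in its error messages, for that type. The error itself occurs because mean() is only defined for floating-point and complex dtypes, so an integer tensor has to be converted (for example with tns.float()) before it can be averaged. See the PyTorch documentation on torch.dtype for an overview of all data types.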
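A small sketch that confirms the alias and shows where the name "Long" comes from: the legacy tensor type string of an int64 CPU tensor is torch.LongTensor. It then applies the fix of casting to a floating-point dtype before calling mean(). The variable name tns follows the question; the choice of .float() as the cast is just one option.

```python
import torch

tns = torch.tensor([1, 0, 1])

# torch.long is another name for the same dtype as torch.int64
print(torch.long == torch.int64)  # True
print(tns.dtype)                  # torch.int64
print(tns.type())                 # torch.LongTensor (legacy name, hence "Long" in the error)

# mean() rejects integer dtypes, so cast to a floating-point dtype first
print(tns.float().mean())         # tensor(0.6667)
```

Casting with tns.to(torch.float64) works the same way if more precision is needed.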