I wrote the code below using PyTorch and ran into a runtime error:
import torch
tns = torch.tensor([1, 0, 1])
tns.mean()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-666-194e5ab56931> in <module>
----> 1 tns.mean()
RuntimeError: mean(): input dtype should be either floating point or complex dtypes. Got Long instead.
However, if I change the tensor to float, the error goes away:
tns = torch.tensor([1.,0,1])
tns.mean()
tensor(0.6667)
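Instead of rewriting the literals as floats, an existing integer tensor can also be cast before averaging. A minimal sketch, assuming a standard PyTorch install; `.float()` converts to `torch.float32`:

```python
import torch

tns = torch.tensor([1, 0, 1])   # integer literals -> dtype torch.int64
# Cast the existing tensor to float32, then take the mean.
mean_val = tns.float().mean()
print(mean_val)                 # a float32 scalar tensor, about 0.6667
```

Casting with `.to(torch.float64)` works the same way if double precision is needed.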
My question is why the error happens. The data type of the first tensor is int64, not Long, so why does PyTorch treat it as Long?
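This is because torch.int64 and torch.long both refer to the same data type: 64-bit signed integers. "Long" is simply the name PyTorch uses for this dtype in error messages and in legacy tensor type names such as torch.LongTensor. The error itself occurs because mean() is only defined for floating-point and complex dtypes, so an integer tensor must be converted first. See here for an overview of all data types.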