I'm trying to convert the type of tensors from DoubleTensor to FloatTensor. However, it seems like the tensors aren't being converted in my code. How can I convert the tensors using a for loop?
```python
import itertools

train_tensorset = [xq_train_tensor, y_train_tensor]
val_tensorset = [xq_val_tensor, y_val_tensor]
test_tensorset = [xq_test_tensor, y_test_tensor]
tensor_list = [train_tensorset, val_tensorset, test_tensorset]
flat_tensor_list = list(itertools.chain.from_iterable(tensor_list))
print(f"Num tensors:{len(flat_tensor_list)}")

for i, tensor in enumerate(flat_tensor_list):
    tensor = flat_tensor_list[i].float()
    print(f"{i}: {tensor.type()}")

print(f"xq_train_tensor: {xq_train_tensor.type()}")
```
To verify that the tensors are being converted, I check the type of `xq_train_tensor`, which is part of `flat_tensor_list`. This was the output of the code above:
```
Num tensors:6
0: torch.FloatTensor
1: torch.FloatTensor
2: torch.FloatTensor
3: torch.FloatTensor
4: torch.FloatTensor
5: torch.FloatTensor
xq_train_tensor: torch.DoubleTensor
```
Even though every item in `flat_tensor_list` is being converted in the for loop, it doesn't seem like the tensors are actually being converted.
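When you call `.float()` on a double tensor, it returns a new tensor object; the original tensor is left unchanged. Inside your loop, `tensor = flat_tensor_list[i].float()` only rebinds the loop variable `tensor` to that new object, so it changes neither the element stored in the list nor `xq_train_tensor`.

When you check `xq_train_tensor.type()` at the end of your code, you are referencing the old object, which is still of double type.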
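A minimal sketch of this behaviour, using a hypothetical zero-filled float64 tensor as a stand-in for the real data: `.float()` hands back a different object, and reassigning the loop variable leaves the list untouched.

```python
import torch

# Hypothetical stand-in for one of the question's double tensors.
xq_train_tensor = torch.zeros(3, dtype=torch.float64)
flat_tensor_list = [xq_train_tensor]

converted = xq_train_tensor.float()   # returns a new float32 tensor
print(converted is xq_train_tensor)   # False: a different object
print(xq_train_tensor.dtype)          # torch.float64: original unchanged

for tensor in flat_tensor_list:
    tensor = tensor.float()           # rebinds the loop name only

print(flat_tensor_list[0].dtype)      # torch.float64: list element unchanged
```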
If you want the variable `xq_train_tensor` to be updated to float type, you need to reassign the variable itself, i.e. `xq_train_tensor = xq_train_tensor.float()`.
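If you want to reassign all six names at once, one option is to convert them in a list comprehension and unpack the results straight back onto the same variables. This is only a sketch; the random float64 tensors and their shapes below are hypothetical placeholders for your real data.

```python
import torch

# Hypothetical float64 stand-ins for the six tensors in the question.
xq_train_tensor, y_train_tensor = torch.rand(8, 3).double(), torch.rand(8).double()
xq_val_tensor, y_val_tensor = torch.rand(2, 3).double(), torch.rand(2).double()
xq_test_tensor, y_test_tensor = torch.rand(2, 3).double(), torch.rand(2).double()

# Convert each tensor and unpack the results back onto the original names,
# so every variable now refers to the new float32 tensor.
(xq_train_tensor, y_train_tensor,
 xq_val_tensor, y_val_tensor,
 xq_test_tensor, y_test_tensor) = [
    t.float()
    for t in (xq_train_tensor, y_train_tensor,
              xq_val_tensor, y_val_tensor,
              xq_test_tensor, y_test_tensor)
]

print(xq_train_tensor.type())  # torch.FloatTensor
```

The unpacking order on the left must match the order of the tuple being converted, otherwise the names get swapped.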
If you want to keep the list-iteration format, you can save the new float tensors to a new list. However, this won't update the `xq_train_tensor` variable, which still points to the old tensor.
```python
output_tensors = []
for tensor in flat_tensor_list:
    output_tensors.append(tensor.float())

for i, tensor in enumerate(output_tensors):
    print(f"{i}: {tensor.type()}")
```
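Alternatively, if you only need the list itself to hold float tensors, you can write each converted tensor back through its index instead of building a second list. A minimal sketch, with hypothetical float64 tensors of ones standing in for `flat_tensor_list`; as above, separate names such as `xq_train_tensor` still point to the old tensors.

```python
import torch

# Hypothetical float64 tensors standing in for flat_tensor_list.
flat_tensor_list = [torch.ones(2, dtype=torch.float64) for _ in range(6)]

# Assigning through the index replaces the list's elements,
# unlike rebinding the loop variable.
for i, tensor in enumerate(flat_tensor_list):
    flat_tensor_list[i] = tensor.float()

for i, tensor in enumerate(flat_tensor_list):
    print(f"{i}: {tensor.type()}")
```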