python, for-loop, pytorch, tensor

How to convert the type of a list of PyTorch tensors using a for loop


I'm trying to convert the type of tensors from DoubleTensor to FloatTensor. However, it seems like the tensors aren't being converted in my code. How can I convert the tensors using a for loop?

import itertools

train_tensorset = [xq_train_tensor,y_train_tensor]
val_tensorset = [xq_val_tensor,y_val_tensor]
test_tensorset = [xq_test_tensor,y_test_tensor]

tensor_list = [train_tensorset,val_tensorset,test_tensorset]
flat_tensor_list = list(itertools.chain.from_iterable(tensor_list))
print(f"Num tensors:{len(flat_tensor_list)}")

for i, tensor in enumerate(flat_tensor_list):
    tensor = flat_tensor_list[i].float()
    print(f"{i}: {tensor.type()}") 

print(xq_train_tensor.type())

To verify that the tensors are being converted, I check the type of xq_train_tensor, which is one of the tensors in flat_tensor_list. This was the output of the code above:

Num tensors:6
0: torch.FloatTensor
1: torch.FloatTensor
2: torch.FloatTensor
3: torch.FloatTensor
4: torch.FloatTensor
5: torch.FloatTensor
torch.DoubleTensor

Even though every item in flat_tensor_list reports FloatTensor inside the for loop, xq_train_tensor is still a DoubleTensor afterwards, so the original tensors don't seem to be converted.


Solution

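  • Calling .float() on a tensor does not modify it in place. It returns a new tensor (unless the tensor is already float32, in which case the same tensor is returned).

    In your loop, that new tensor is only bound to the loop variable tensor, so neither flat_tensor_list nor xq_train_tensor changes. When you run print(xq_train_tensor.type()) at the end of your code, you are still referencing the original object, which is a DoubleTensor.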

    If you want the variable xq_train_tensor to be updated to float type, you need to reassign the variable itself.

    i.e. xq_train_tensor = xq_train_tensor.float()

    If you want to keep the loop, you can collect the new float tensors in a new list. However, this won't update the xq_train_tensor variable, which still points to the old tensor.

    # Collect the converted tensors; the originals are left unchanged.
    output_tensors = []
    for tensor in flat_tensor_list:
        output_tensors.append(tensor.float())
    
    for i, tensor in enumerate(output_tensors):
        print(f"{i}: {tensor.type()}")
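
If you also want the original variable names to end up pointing at the float tensors, one option is to unpack the converted list back into those names. The following is only a minimal sketch: the tensor shapes and the zero-filled data are placeholders I made up so that it runs on its own, and it uses just two of the six tensors from the question.

```python
import torch

# Placeholder stand-ins for two of the question's tensors (shapes are assumed).
xq_train_tensor = torch.zeros(4, 3, dtype=torch.float64)
y_train_tensor = torch.zeros(4, dtype=torch.float64)

# Keep a second reference to the original object to show it is not modified.
original = xq_train_tensor

tensors = [xq_train_tensor, y_train_tensor]

# .float() returns a new float32 tensor for each float64 input.
converted = [t.float() for t in tensors]

# Rebind the original names by unpacking the converted list.
xq_train_tensor, y_train_tensor = converted

print(original.type())         # the old object is still a DoubleTensor
print(xq_train_tensor.type())  # the name now refers to the new FloatTensor
```

The unpacking line is just the list form of xq_train_tensor = xq_train_tensor.float(): the conversion still creates new tensors, and the names only change because they are assigned again.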