
Why would we use the to() method in PyTorch?


I've seen this method multiple times. What are the purposes and advantages of doing this?


Solution

  • Why would we use the to(device) method in PyTorch?

    torch.Tensor.to is a multipurpose method.

    Not only can it convert a tensor's dtype, it can also move a tensor from CPU to GPU and back:

    import torch

    tensor = torch.randn(2, 2)
    print(tensor)
    tensor = tensor.to(torch.float64)
    print(tensor)  # dtype=torch.float64
    tensor = tensor.to("cuda")  # requires a CUDA-capable GPU
    print(tensor)  # device='cuda:0', dtype=torch.float64
    tensor = tensor.to("cpu")
    print(tensor)  # dtype=torch.float64
    tensor = tensor.to(torch.float32)
    print(tensor)  # dtype not shown: float32 is the default
    

    Since CPU and GPU memory are separate, there must be a way to move data between them. This is why we call to("cuda") and to("cpu") on a tensor.
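    A small sketch of the same idea (my own example, not from the original answer): to() returns the tensor itself when the dtype and device already match, and it can take a device and a dtype in a single call. The GPU move is guarded so the snippet also runs on CPU-only machines.

    ```python
    import torch

    t = torch.zeros(3)               # float32 on CPU by default
    print(t.to(torch.float32) is t)  # True: nothing to change, same tensor returned

    # Guard the device choice so this works with or without a GPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    moved = t.to(device, torch.float64)  # device and dtype in one call
    print(moved.dtype)               # torch.float64
    ```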

    Usually, when you load a training dataset (e.g. images), you create tensors on the CPU and then move them to the GPU like this:

    torch.zeros(1000).to("cuda")
    

    But there is a trick: sometimes you can allocate tensors directly on the GPU without going through CPU memory first.

    torch.zeros(1000, device="cuda")
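    Building on that, a common device-agnostic pattern (a sketch of typical usage, not part of the original answer) picks the device once, then sends both the model and each batch to it; the same script then runs unchanged on CPU or GPU:

    ```python
    import torch

    # Pick the device once; fall back to CPU when no GPU is present
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.nn.Linear(10, 2).to(device)  # moves the parameters to the device
    batch = torch.randn(4, 10, device=device)  # allocated directly on the device

    output = model(batch)
    print(output.shape)  # torch.Size([4, 2])
    ```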