
PyTorch: difference between reshape() and view() method


What is the difference between the reshape() and view() methods, and why do we need both? I am using PyTorch tensors and working on changing the shape of my data, and that is how I came across these two functions. What are their effects on memory: which one consumes more memory, and which is more expensive if we are working with large data and limited resources?

import torch

x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
x = x.reshape(-1, 1)

and with the view() method:

x = x.view(x.shape[0], 1)

What is the difference, and which one should I use?


Solution

  • The short answer: When reshaping a contiguous tensor, both methods do the same thing (namely, provide a new view of the given tensor) and can therefore be used interchangeably. When reshaping a non-contiguous tensor into a shape that cannot be expressed as a view of its memory, reshape() will copy the necessary parts of memory from the given tensor to produce the result, while view() will fail with a RuntimeError.

    The long answer

    The main difference is how torch.Tensor.reshape() and torch.Tensor.view() handle non-contiguous tensors.

    To understand the difference, we need to understand what a contiguous tensor is and what a view of a tensor is:

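    • A contiguous tensor is a tensor whose values are stored in a single, uninterrupted – thus, "contiguous" – piece of memory, in the same (row-major) order in which the tensor's elements are enumerated. A non-contiguous tensor may have gaps in its memory layout, or its elements may be stored in a different order (as is the case for a transposed tensor).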
    • Producing a view of a tensor means reinterpreting the arrangement of values in its memory. Think of a piece of memory that stores 16 values: we can interpret it, for example, as a 16-element 1-d tensor or as a 4×4-element 2-d tensor. Both interpretations can use the same underlying memory. By using views and thus reinterpreting the memory layout, we can create differently shaped tensors from the same piece of memory, in this way avoiding duplication and saving memory.
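As a minimal sketch of these two notions (assuming a recent PyTorch version; the variable names are only illustrative), the following reinterprets 16 values as a 4×4 tensor without copying and checks contiguity with `is_contiguous()`. It also shows that a transpose, although it uses the same uninterrupted piece of memory, is reported as non-contiguous because its elements are no longer stored in row-major order.

```python
import torch

# 16 values stored in one uninterrupted block of memory (a contiguous 1-d tensor)
flat = torch.arange(16)

# Reinterpret the same memory as a 4x4 tensor: no copy is made
grid = flat.view(4, 4)
print(grid.data_ptr() == flat.data_ptr())        # True: same underlying memory
print(grid.is_contiguous())                      # True

# A transpose is also a view of the same memory, but its elements are no longer
# laid out in row-major order, so PyTorch reports it as non-contiguous
transposed = grid.t()
print(transposed.data_ptr() == grid.data_ptr())  # True
print(transposed.is_contiguous())                # False

# Slicing off the last column of a 4x5 tensor leaves gaps between rows
sliced = torch.arange(20).view(4, 5)[:, :4]
print(sliced.is_contiguous())                    # False
```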

    Now back to the two methods:

    • If applied to a contiguous tensor, both reshape() and view() will produce a new view of the given tensor's memory, in this way avoiding duplication.
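    • If applied to a non-contiguous tensor, creating a view may not be possible, depending on whether the requested shape is compatible with how the tensor's values are laid out in memory. When it is not, the two methods handle this situation differently: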
      • The reshape() method will duplicate the necessary piece of memory and will return a tensor whose memory will not be shared with the given tensor.
      • The view() method will produce a RuntimeError.

    We can demonstrate this with the following piece of code:

    from torch import arange
    
    contiguous = arange(16).view(4, 4)             # Create contiguous 4×4 tensor
    noncontiguous = arange(20).view(4, 5)[:, :4]   # Create non-contiguous 4×4 tensor
    
    contiguous_r = contiguous.reshape(16)          # OK: produces a 1-d view
    assert contiguous_r.data_ptr() == contiguous.data_ptr()  # Same memory used
    
    contiguous_v = contiguous.view(16)             # OK: produces a 1-d view
    assert contiguous_v.data_ptr() == contiguous.data_ptr()  # Same memory used
    
    noncontiguous_r = noncontiguous.reshape(16)    # OK: produces a new 1-d tensor (a copy)
    assert noncontiguous_r.data_ptr() != noncontiguous.data_ptr()  # New memory used
    
    noncontiguous_v = noncontiguous.view(16)       # ERROR: cannot produce view
    

    The last line will produce RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
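If you want any copying to be explicit in your code rather than hidden inside reshape(), one option is to call contiguous() yourself before view(). The sketch below (illustrative names, assuming a recent PyTorch version) catches the error raised by the last line above and checks that the explicit route yields the same values as reshape(). Whether reshape() internally does exactly this is an implementation detail the sketch does not rely on.

```python
import torch

noncontiguous = torch.arange(20).view(4, 5)[:, :4]

view_failed = False
try:
    noncontiguous.view(16)
except RuntimeError as err:
    # view() refuses because producing this shape would require a copy
    view_failed = True
    print("view() failed:", err)

# Making an explicit contiguous copy first lets view() succeed
explicit = noncontiguous.contiguous().view(16)

# reshape() produces the same values in a single call
implicit = noncontiguous.reshape(16)

print(torch.equal(explicit, implicit))                  # True
print(explicit.data_ptr() != noncontiguous.data_ptr())  # True: new memory
```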

    Maybe at this point, I should also mention what a tensor's stride is: in essence, it is the information that tells us how to map the tensor's indexes to its underlying memory, namely, for each dimension, how many elements to skip in memory to move one step along that dimension. You will find more information on strides in particular and on contiguous vs. non-contiguous tensors in general, for example, in this discussion in the PyTorch forum.
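To make strides concrete, here is a small sketch (assuming a recent PyTorch version) that prints them for the two tensors from the example above. It also illustrates a nuance: view() does not strictly require a contiguous tensor, only a target shape that can be expressed with suitable strides. Splitting the first dimension of the non-contiguous tensor works as a view, while flattening it does not.

```python
import torch

contiguous = torch.arange(16).view(4, 4)
noncontiguous = torch.arange(20).view(4, 5)[:, :4]

# stride()[d] = number of elements to skip in memory to move one step along dim d
print(contiguous.stride())     # (4, 1)
print(noncontiguous.stride())  # (5, 1): rows start 5 apart, but only 4 are used

# Splitting the first dimension only needs the new strides (10, 5, 1), so a view
# is possible even though the tensor is non-contiguous
split = noncontiguous.view(2, 2, 4)
print(split.stride())                                # (10, 5, 1)
print(split.data_ptr() == noncontiguous.data_ptr())  # True: no copy

# Flattening to 16 elements would need a single stride across rows, which the
# gaps rule out, so noncontiguous.view(16) raises (and reshape(16) copies)
```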

    As to your question, "Which should I use?":

    • I would recommend using reshape() if you just want to make sure that you will receive a result when reshaping a tensor: for a contiguous tensor, reshape() will do exactly the same as view() (namely, produce a view of the given tensor without duplicating its memory). For a non-contiguous tensor whose memory layout is incompatible with the requested shape, reshape() will be the only one of the two methods that produces a result, while view() will fail (see above).
    • I would recommend using view() if you are actually sure that the tensors to be reshaped are contiguous and want to catch the situation where they aren't. This may be meaningful, for example, if you work in a low-memory regime and would rather fail than duplicate memory. Another use case of view() is reinterpreting not the shape but the data type of the underlying memory. I guess this is not your use case, but you will find more on this in the documentation of view() that I linked above.
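The sketch below (illustrative names, assuming a recent PyTorch version that supports `Tensor.view(dtype)`) illustrates two points from the recommendations above. First, since reshape() may return either a view or a copy, writing into its result may or may not modify the original tensor, so code should not depend on which one happens. Second, view() can reinterpret the bytes of a tensor as another data type with the same element size.

```python
import torch

contiguous = torch.arange(16).view(4, 4)
noncontiguous = torch.arange(20).view(4, 5)[:, :4]

# reshape() on a contiguous tensor returns a view: writes reach the original
r1 = contiguous.reshape(16)
r1[0] = 100
print(contiguous[0, 0].item())     # 100

# reshape() on this non-contiguous tensor returns a copy: the original is unchanged
r2 = noncontiguous.reshape(16)
r2[0] = 100
print(noncontiguous[0, 0].item())  # 0

# view(dtype) reinterprets the 4 bytes of each float32 as an int32 (no copy)
bits = torch.tensor([1.0], dtype=torch.float32).view(torch.int32)
print(bits.item())  # 1065353216, i.e. 0x3F800000, the IEEE 754 encoding of 1.0
```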