Tags: python, memory, pytorch, contiguous

What does .contiguous() do in PyTorch?


What does x.contiguous() do for a tensor x?


Solution

  • A few tensor operations in PyTorch do not change the contents of a tensor, only the way its data is indexed. These operations include:

    narrow(), view(), expand() and transpose()

    For example, when you call transpose(), PyTorch doesn't allocate a new tensor with a new layout. It only modifies metadata in the Tensor object so that the offset and stride describe the desired new shape. In this example, the transposed tensor and the original tensor share the same memory:

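    import torch

    x = torch.randn(3, 2)
    y = torch.transpose(x, 0, 1)  # y is a view of x, not a copy
    x[0, 0] = 42                  # write through x ...
    print(y[0, 0])                # ... and the change is visible through y
    # prints tensor(42.)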
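
    To see which metadata actually changes, it can help to inspect the stride and storage offset directly. The following is a minimal sketch rather than part of the original answer; it uses torch.arange instead of torch.randn so the values are predictable, and the variable names are chosen only for illustration:

```python
import torch

x = torch.arange(6).reshape(3, 2)   # values 0..5 stored in row-major order
y = x.transpose(0, 1)               # shape (2, 3), no data is copied
z = x.narrow(0, 1, 2)               # rows 1 and 2 of x, also without a copy

# x and y start at the same memory address: they share one buffer.
same_buffer = x.data_ptr() == y.data_ptr()
print(same_buffer)                  # True

# transpose() only swapped the strides (steps, in elements, per dimension).
print(x.stride())                   # (2, 1)
print(y.stride())                   # (1, 2)

# narrow() kept the strides but moved the starting offset into the buffer.
print(z.stride(), z.storage_offset())   # (2, 1) 2
```

    Strides are counted in elements, not bytes, so a stride of (2, 1) means "step 2 elements to reach the next row, 1 element to reach the next column".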
    

    This is where the concept of contiguity comes in. In the example above, x is contiguous but y is not, because y's memory layout differs from that of a tensor of the same shape created from scratch. Note that the word "contiguous" is a bit misleading: the tensor's contents are not scattered across disconnected blocks of memory. The elements still live in one block of memory, but their order in that block differs from the order implied by the tensor's shape.
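
    One way to check this, sketched here under the assumption that a tensor "made from scratch" means one built from plain Python lists with torch.tensor, is to compare is_contiguous() and the strides of the transposed view against a freshly built tensor holding the same values:

```python
import torch

x = torch.arange(6).reshape(3, 2)
y = x.transpose(0, 1)              # logical contents: [[0, 2, 4], [1, 3, 5]]

print(x.is_contiguous())           # True
print(y.is_contiguous())           # False

# Build a tensor with the same values as y from scratch.
fresh = torch.tensor(y.tolist())
print(torch.equal(y, fresh))       # True: same shape and values
print(y.stride(), fresh.stride())  # (1, 2) (3, 1): different memory order
```

    Both tensors read back the same values, but only the fresh one stores them in row-major order.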

    When you call contiguous() on a non-contiguous tensor, it makes a copy whose elements are ordered in memory as if the tensor had been created from scratch with the same data. If the tensor is already contiguous, contiguous() returns the tensor itself without copying.
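
    The effect of contiguous() can be observed roughly as follows. This is an illustrative sketch with made-up variable names; it compares data_ptr() values to tell whether new memory was allocated:

```python
import torch

x = torch.arange(6).reshape(3, 2)
y = x.transpose(0, 1)              # non-contiguous view of x

yc = y.contiguous()                # copies the data into row-major order
copied = yc.data_ptr() != y.data_ptr()
print(copied)                      # True: yc lives in new memory
print(yc.stride())                 # (3, 1)

xc = x.contiguous()                # x is already contiguous
no_copy = xc.data_ptr() == x.data_ptr()
print(no_copy)                     # True: no new memory was allocated

# Because yc is a copy, later writes to x reach y but not yc.
x[0, 0] = 42
print(y[0, 0].item(), yc[0, 0].item())   # 42 0
```

    The last line is the practical consequence: after contiguous() has copied, the result no longer tracks changes to the original tensor.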

    Normally you don't need to worry about this. It is generally safe to assume everything will work, and to add a call to contiguous() only when you get a RuntimeError: input is not contiguous, which is raised where PyTorch expects a contiguous tensor.
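
    A common place to hit such an error is view() on a transposed tensor. The sketch below is an assumption-laden illustration: the exact error message depends on the operation and the PyTorch version, so it will likely differ from the message quoted above. It also shows reshape(), which is not mentioned in the answer but copies on its own when a view is not possible:

```python
import torch

x = torch.arange(6).reshape(3, 2)
y = x.transpose(0, 1)              # non-contiguous view, shape (2, 3)

try:
    y.view(6)                      # cannot flatten this layout without copying
    view_failed = False
except RuntimeError as err:
    view_failed = True
    print("view() failed:", err)

flat_a = y.contiguous().view(6)    # copy into row-major order, then view
flat_b = y.reshape(6)              # reshape() copies by itself when it has to
print(flat_a.tolist())             # [0, 2, 4, 1, 3, 5]
```

    Either fix yields the same flattened values; calling contiguous() explicitly just makes the copy visible in the code.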