
What does `view()` do in PyTorch?


What does view() do to the tensor x? What do negative values mean?

x = x.view(-1, 16 * 5 * 5)

Solution

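  • view() returns a tensor with a new shape that shares the same underlying data as the original, so no memory is copied. It is similar to numpy's reshape() in the cases where that returns a view.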

    Given a tensor a with 16 elements:

    import torch
    a = torch.arange(1, 17)  # values 1..16; torch.range() is deprecated, and arange() excludes the end value
    

    To reshape this tensor to make it a 4 x 4 tensor, use:

    a = a.view(4, 4)
    

    Now a is a 4 x 4 tensor. Note that the total number of elements must stay the same after the reshape: reshaping a to a 3 x 5 tensor fails with a RuntimeError, because 3 * 5 = 15 is not 16.
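
The two points above (no copy, and a fixed element count) can be checked directly. This is a minimal sketch, assuming a recent PyTorch install; the variable names are only for illustration.

```python
import torch

a = torch.arange(1, 17)      # 16 elements: 1, 2, ..., 16
b = a.view(4, 4)             # same data, new shape

# view() shares storage with the original tensor, so a write
# through b is visible through a as well.
b[0, 0] = 100
shares_memory = a.data_ptr() == b.data_ptr()
first_element = a[0].item()  # reflects the write made through b

# 3 * 5 = 15 elements, but a holds 16, so this shape is rejected.
try:
    a.view(3, 5)
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = exc

print(b.shape, shares_memory, first_element)
print(mismatch_error)
```

Because the storage is shared, modifying the viewed tensor also modifies the original; call clone() first if you need an independent copy.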

    What is the meaning of parameter -1?

    If you don't know how many rows you want but are sure of the number of columns, you can pass -1 for the number of rows. (This extends to tensors with more dimensions, but only one of the sizes can be -1.) It is a way of telling the library: "give me a tensor with this many columns, and compute the number of rows needed to make that work". The inferred size is the total number of elements divided by the product of the sizes you did specify.
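
As a rough illustration of how -1 is resolved, the sketch below reuses the 16-element tensor from earlier. The shapes chosen are arbitrary examples, not anything from the question's model.

```python
import torch

a = torch.arange(1, 17)            # 16 elements

rows_inferred = a.view(-1, 8)      # 16 / 8 = 2 rows
cols_inferred = a.view(2, -1)      # 16 / 2 = 8 columns
three_d = a.view(2, -1, 4)         # 16 / (2 * 4) = 2 for the middle size
flat = rows_inferred.view(-1)      # a single -1 flattens to 16 elements

print(rows_inferred.shape)         # torch.Size([2, 8])
print(cols_inferred.shape)         # torch.Size([2, 8])
print(three_d.shape)               # torch.Size([2, 2, 4])
print(flat.shape)                  # torch.Size([16])

# Only one size may be left for PyTorch to infer.
try:
    a.view(-1, -1)
    two_unknowns_error = None
except RuntimeError as exc:
    two_unknowns_error = exc
```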

    This can be seen in this model definition code. After the line x = self.pool(F.relu(self.conv2(x))) in the forward function, x is a feature map with 16 channels. It has to be flattened before it is passed to the fully connected layer, so you tell PyTorch to reshape it to 16 * 5 * 5 = 400 columns and to work out the number of rows, which here is the batch size, by itself.
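
The flattening step can be tried in isolation. In this sketch a random tensor stands in for the output of the pooling layer; the batch size of 4 and the 120 output features of the linear layer are assumptions made for illustration, not values taken from the question.

```python
import torch
import torch.nn as nn

# Stand-in for the output of self.pool(F.relu(self.conv2(x))):
# a hypothetical batch of 4 samples, 16 channels, 5 x 5 each.
x = torch.randn(4, 16, 5, 5)

flat = x.view(-1, 16 * 5 * 5)   # -1 is inferred as the batch size
print(flat.shape)               # torch.Size([4, 400])

# Assumed fully connected layer; 120 output features is arbitrary.
fc1 = nn.Linear(16 * 5 * 5, 120)
out = fc1(flat)
print(out.shape)                # torch.Size([4, 120])
```

Using -1 for the first size means the same line works for any batch size, as long as each sample still has 16 * 5 * 5 values.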