How to understand torch.arange(0, 3).view(-1, *[1]*3)


In PyTorch, for the code:

torch.arange(0, 3).view(-1, *[1]*3)

The result is:

tensor([[[[0]]],


        [[[1]]],


        [[[2]]]])

and its shape is:

torch.Size([3, 1, 1, 1])

Here [1] * 3 = [1, 1, 1], but I don't understand the * before [1] * 3. What does it mean? Thanks.


Solution

    While the links provided in the comments explain parts of the solution, the whole picture may be missing, so let’s disentangle this view call step by step:

    .view(-1,...)
    

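    Means “infer this dimension from the total number of elements”. In your case the tensor [0, 1, 2] has 3 elements and the remaining dimensions multiply to 1 * 1 * 1 = 1, so -1 resolves to 3.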
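A quick check of how view resolves -1, assuming PyTorch is installed; the variable names are illustrative only:

```python
import torch

x = torch.arange(0, 3)        # tensor([0, 1, 2]), 3 elements in total

# -1 is inferred as 3 // (1 * 1 * 1) = 3
a = x.view(-1, 1, 1, 1)
print(a.shape)                # torch.Size([3, 1, 1, 1])

# Same rule with another layout: 6 elements, one dimension fixed at 2,
# so -1 is inferred as 6 // 2 = 3
b = torch.arange(6).view(-1, 2)
print(b.shape)                # torch.Size([3, 2])
```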

    Next:

    [1] * 3
    

    Is a Python idiom for creating a new list with a single element repeated multiple times.

    It is the same as

    [1, 1, 1]
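In plain Python, multiplying a list by an integer repeats its contents; a minimal illustration:

```python
# Multiplying a list by an int repeats its contents
ones = [1] * 3
print(ones)          # [1, 1, 1]

# With a longer list, the whole list is repeated
pair = [1, 2] * 2
print(pair)          # [1, 2, 1, 2]
```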
    

    Now the asterisk (*) unpacks the list, passing each of its elements to the function as a separate positional argument. Note that *[1] * 3 is read as *([1] * 3), so in this case:

    .view(-1, *[1, 1, 1])
    

    Becomes:

    .view(-1, 1, 1, 1)
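To see the unpacking in isolation, here is a plain-Python sketch; `show_args` is a hypothetical helper that simply returns the positional arguments it receives, standing in for a call like view:

```python
def show_args(*args):
    """Hypothetical helper: return the positional arguments exactly as received."""
    return args

shape = [1] * 3

# Without the asterisk, the list arrives as one single argument
without_star = show_args(-1, shape)
print(without_star)   # (-1, [1, 1, 1])

# With the asterisk, each element becomes its own argument
with_star = show_args(-1, *shape)
print(with_star)      # (-1, 1, 1, 1)

# The form from the question: * applies to the whole expression [1] * 3
inline = show_args(-1, *[1] * 3)
print(inline)         # (-1, 1, 1, 1)
```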
    

    Combined with the first step, where -1 resolves to 3, the whole call is equivalent to:

    .view(3, 1, 1, 1)
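Putting the steps together, a short check (assuming PyTorch is installed) that the original expression and its two rewritten forms produce identical tensors:

```python
import torch

x = torch.arange(0, 3)

original = x.view(-1, *[1] * 3)     # the form from the question
spelled_out = x.view(-1, 1, 1, 1)   # after unpacking the list
explicit = x.view(3, 1, 1, 1)       # after resolving -1 to 3

# All three have the same values and the same shape
print(torch.equal(original, spelled_out))   # True
print(torch.equal(original, explicit))      # True
print(original.shape)                       # torch.Size([3, 1, 1, 1])
```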
    

    By the way, avoid writing it this way in most circumstances: as the breakdown above shows, it is pretty hard to follow.