
Unflatten in PyTorch


I need to change the shape of a tensor from [2, 48, 196] to [2, 48, 14, 14]. I read that there is an "unflatten" in PyTorch, but I couldn't understand how to use it. Is there an example?


Solution

  • Here is an example for your question.

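    import torch
    
    x = torch.randn(2, 48, 196)  # original tensor, shape [2, 48, 196]
    unflatten = torch.nn.Unflatten(2, (14, 14))  # split dim 2 (size 196) into (14, 14)
    output = unflatten(x)
    print(output.shape)  # torch.Size([2, 48, 14, 14])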
    

    If you check output.shape, you get torch.Size([2, 48, 14, 14]).

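    The Unflatten module expands one dimension of a tensor into several dimensions with a desired shape. In your case, you want to expand dimension 2, which has size 196, into two dimensions of shape (14, 14). This works because 14 × 14 = 196: the product of the new sizes must equal the size of the dimension being unflattened.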
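Unflatten only regroups the existing elements; it does not reorder them. The following minimal sketch compares it with Tensor.reshape (the reshape comparison is an illustration added here, not part of the original answer) and checks that element (i, j) of the new 14 × 14 grid is element i * 14 + j of the original size-196 dimension, i.e. row-major order.

```python
import torch

x = torch.randn(2, 48, 196)

# nn.Unflatten splits dim 2 (size 196) into two dims of size 14 and 14
unflatten = torch.nn.Unflatten(2, (14, 14))
out_unflatten = unflatten(x)

# The same result with reshape: keep dims 0 and 1, split the last one
out_reshape = x.reshape(x.shape[0], x.shape[1], 14, 14)

print(out_unflatten.shape)                      # torch.Size([2, 48, 14, 14])
print(torch.equal(out_unflatten, out_reshape))  # True

# Row-major mapping: position (3, 5) in the grid is index 3 * 14 + 5 = 47 in dim 2
print(torch.equal(out_unflatten[:, :, 3, 5], x[:, :, 3 * 14 + 5]))  # True
```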

    Unflatten takes two parameters:

    1. dim: the dimension you want to unflatten. In your case, it is 2.
    2. unflattened_size: the new shape of the unflattened dimension. In your case, it is (14, 14).
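A short sketch of how the two parameters behave on the same [2, 48, 196] input. The alternative shapes (7, 28) and (4, 7, 7) are arbitrary illustrations chosen because their products equal 196; (14, 15) is chosen because its product (210) does not, so calling the module raises an error.

```python
import torch

x = torch.randn(2, 48, 196)

# dim may be negative: -1 refers to the last dimension (here, dim 2)
shape_neg = torch.nn.Unflatten(-1, (14, 14))(x).shape
print(shape_neg)  # torch.Size([2, 48, 14, 14])

# unflattened_size can be any shape whose product is 196, e.g. 7 * 28
shape_7x28 = torch.nn.Unflatten(2, (7, 28))(x).shape
print(shape_7x28)  # torch.Size([2, 48, 7, 28])

# It can also split into more than two dims: 4 * 7 * 7 = 196
shape_4x7x7 = torch.nn.Unflatten(2, (4, 7, 7))(x).shape
print(shape_4x7x7)  # torch.Size([2, 48, 4, 7, 7])

# A shape whose product is not 196 (14 * 15 = 210) fails when the module is called
mismatch_rejected = False
try:
    torch.nn.Unflatten(2, (14, 15))(x)
except RuntimeError as err:
    mismatch_rejected = True
    print("RuntimeError:", err)
```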

    Therefore, your Unflatten module should look like: unflatten = torch.nn.Unflatten(2, (14, 14))
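Unflatten is the inverse of nn.Flatten, and because it is an nn.Module it can be placed inside nn.Sequential. The sketch below assumes, for illustration only, that the [2, 48, 196] tensor came from flattening a 14 × 14 grid (the question does not say where it came from); the nn.Linear layer is a hypothetical placeholder for whatever acts on the flattened dimension.

```python
import torch
from torch import nn

# Flatten the last two dims (14, 14) into one (196), then restore them
flatten = nn.Flatten(start_dim=2)      # [2, 48, 14, 14] -> [2, 48, 196]
unflatten = nn.Unflatten(2, (14, 14))  # [2, 48, 196] -> [2, 48, 14, 14]

images = torch.randn(2, 48, 14, 14)
tokens = flatten(images)
restored = unflatten(tokens)

print(tokens.shape)                   # torch.Size([2, 48, 196])
print(restored.shape)                 # torch.Size([2, 48, 14, 14])
print(torch.equal(images, restored))  # True

# Unflatten inside nn.Sequential
model = nn.Sequential(
    nn.Linear(196, 196),        # hypothetical layer acting on the last dim
    nn.Unflatten(2, (14, 14)),
)
model_out = model(tokens)
print(model_out.shape)  # torch.Size([2, 48, 14, 14])
```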