
Upsampling a 4D tensor in PyTorch


I created a 4D tensor of shape (B, C, H, W).

I want to expand this tensor to shape (B, C, 2H, 2W).

Each value expands into a 2×2 block of 4 values: the original value is placed at the position within the block given by the index tensor, and the remaining positions are filled with zeros.

The position of each value within its 2×2 block follows this numbering:

0 1

2 3

And here is an example below:

original tensor:
torch.Size([1, 1, 2, 2])
tensor([[[[1.0000, 0.4000],
          [0.2000, 0.5000]]]])

index tensor:
torch.Size([1, 1, 2, 2])
tensor([[[[0, 2],
          [1, 1]]]])
output tensor:
torch.Size([1, 1, 4, 4])
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.4000, 0.0000],
          [0.0000, 0.2000, 0.0000, 0.5000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]])

How can I transform this tensor efficiently while making good use of the GPU? (Maybe we can use torch.Tensor.scatter_?)


Solution

  • My initial thought was that you can utilize an nn.MaxUnpool2d layer to unpool the original tensor x at the desired locations from index. However, your current setup provides those indices locally, based on your square 0, 1, 2, 3 layout. This means that in order to effectively use an unpooling layer, we first need to convert your current index tensor from this:

    tensor([[[[0, 2],
              [1, 1]]]])
    

    To something like:

    tensor([[[[0,  6],
              [9, 11]]]])
    

    Where each index refers to the flat location in the 2H×2W output at which the moving window will place the corresponding value from x.

    In order to do so, we need to transform index as follows:

    1. First convert your index tensor to a one-hot encoding to expand the number of elements, using nn.functional.one_hot:

      >>> one_hot(index, 4)
      tensor([[[[[1, 0, 0, 0],      # [[[[[a0, a1, a2, a3],
                 [0, 0, 1, 0]],     #     [b0, b1, b2, b3]],
      
                [[0, 1, 0, 0],      #    [[c0, c1, c2, c3],
                 [0, 1, 0, 0]]]]])  #     [d0, d1, d2, d3]]]]]
      

      We then need to convert the current tensor to the following layout:

      # [[[[[a0, a1, b0, b1],
      #     [a2, a3, b2, b3]],
      
      #    [[c0, c1, d0, d1],
      #     [c2, c3, d2, d3]]]]]
      
    2. Begin by viewing your tensor as a set of 2×2 blocks:

      >>> ohe = one_hot(index, 4).view(2,2,2,2)
      tensor([[[[1, 0],     # [[[[a0, a1],
                [0, 0]],    #    [a2, a3]],
      
               [[0, 0],     #   [[b0, b1],
                [1, 0]]],   #    [b2, b3]],
      
      
              [[[0, 1],     #  [[[c0, c1],
          [0, 0]],    #    [c2, c3]],
      
               [[0, 1],     #   [[d0, d1],
                [0, 0]]]])  #    [d2, d3]]]]
      
    3. The series of steps taken so far allows us to recover the desired layout without having to split and concatenate the elements. You can find a detailed explanation of the process in this other answer.

      >>> y = ohe.transpose(-1,-2).reshape(2,-1,2).transpose(-1,-2)
      tensor([[[1, 0, 0, 0],      # [[[[[a0, a1, b0, b1],
               [0, 0, 1, 0]],     #     [a2, a3, b2, b3]],
      
              [[0, 1, 0, 1],      #    [[c0, c1, d0, d1],
               [0, 0, 0, 0]]])    #     [c2, c3, d2, d3]]]]]
      
    4. Then flatten the mask and recover the flat indices by multiplying with an increasing range of values, offset by one so that position 0 still shows up as non-zero:

      >>> y.reshape(-1)
      tensor([ 1,  0,  0,  0,  0,  0,  1,  0,  0,  1,  0,  1,  0,  0,  0, 0])
      
      >>> i = y.reshape(-1)*(torch.arange(4*x.numel())+1)
      tensor([ 1,  0,  0,  0,  0,  0,  7,  0,  0, 10,  0, 12,  0,  0,  0, 0])
      
    5. The desired index tensor is then given by the positions of the non-zero entries:

      >>> unpool_i = i.nonzero().view_as(x)
      tensor([[[[ 0,  6],
                [ 9, 11]]]])
      
    6. Finally you can apply the unpooling layer with:

      >>> unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
      >>> unpool(x, unpool_i)
      tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.4000, 0.0000],
                [0.0000, 0.2000, 0.0000, 0.5000],
                [0.0000, 0.0000, 0.0000, 0.0000]]]])
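Putting the steps above together, here is a self-contained end-to-end sketch for the 1×1×2×2 example (it assumes a single batch and channel, since the flat arange indexing spans the whole tensor):

```python
import torch
from torch import nn
from torch.nn.functional import one_hot

x = torch.tensor([[[[1.0, 0.4],
                    [0.2, 0.5]]]])
index = torch.tensor([[[[0, 2],
                        [1, 1]]]])

# Steps 1-3: one 2x2 one-hot block per input element,
# rearranged into the 4x4 output layout.
ohe = one_hot(index, 4).view(2, 2, 2, 2)
y = ohe.transpose(-1, -2).reshape(2, -1, 2).transpose(-1, -2)

# Steps 4-5: recover flat output positions; the +1 offset keeps
# position 0 visible to nonzero().
i = y.reshape(-1) * (torch.arange(4 * x.numel()) + 1)
unpool_i = i.nonzero().view_as(x)

# Step 6: unpool x at those positions.
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
out = unpool(x, unpool_i)
```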
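Alternatively, since the question mentions scatter_: assuming the 0–3 indices are row-major within each 2×2 block, the flat output position of each value can be computed in closed form and scattered directly, which also handles arbitrary B, C, H, W. (expand_with_index is a hypothetical name, not part of any API.)

```python
import torch

def expand_with_index(x, index):
    # x, index: (B, C, H, W); index values in {0, 1, 2, 3} pick the
    # slot of each value inside its 2x2 output block (row-major).
    B, C, H, W = x.shape
    r, c = index // 2, index % 2          # local row/col inside each 2x2 block
    h = torch.arange(H, device=x.device).view(1, 1, H, 1)
    w = torch.arange(W, device=x.device).view(1, 1, 1, W)
    # Flat position inside the (2H, 2W) output map.
    flat = (2 * h + r) * (2 * W) + (2 * w + c)
    out = torch.zeros(B, C, 4 * H * W, dtype=x.dtype, device=x.device)
    out.scatter_(2, flat.view(B, C, -1), x.view(B, C, -1))
    return out.view(B, C, 2 * H, 2 * W)
```

The flat position of a value at (h, w) with local index k is (2h + k//2) * 2W + (2w + k%2), which reproduces the global indices derived for the unpooling layer.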