I created a 4D tensor of size (B, C, H, W), and I want to expand it to size (B, C, 2H, 2W).
Each value expands into a 2x2 block of 4 values: the original value is kept at the position within the block given by the index tensor, and the other 3 positions are filled with zeros.
The positions within each 2x2 block are numbered as follows:
0 1
2 3
Here is an example:
original tensor:
torch.Size([1, 1, 2, 2])
tensor([[[[1.0000, 0.4000],
          [0.2000, 0.5000]]]])
index tensor:
torch.Size([1, 1, 2, 2])
tensor([[[[0, 2],
          [1, 1]]]])
output tensor:
torch.Size([1, 1, 4, 4])
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.4000, 0.0000],
          [0.0000, 0.2000, 0.0000, 0.5000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]])
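For reference, the example tensors above can be constructed with:

>>> import torch
>>> x = torch.tensor([[[[1.0, 0.4],
...                     [0.2, 0.5]]]])
>>> index = torch.tensor([[[[0, 2],
...                         [1, 1]]]])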
How can I transform this tensor efficiently, while making good use of the GPU? (Maybe we can use torch.Tensor.scatter_.)
My initial thought was that you could utilize an nn.MaxUnpool2d layer to unpool the original tensor x at the desired locations given by index. However, your current setup provides those indices locally, based on the 0, 1, 2, 3 layout within each 2x2 block. This means that in order to effectively use an unpooling layer, we first need to convert your current index tensor from this:
tensor([[[[0, 2],
          [1, 1]]]])
To something like:
tensor([[[[ 0,  6],
          [ 9, 11]]]])
where each index refers to the flat location in the 4x4 output at which the unpooling window will place the corresponding value of x.
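As a quick sanity check (a hedged alternative sketch, not one of the steps below): the local index k inside the block at grid position (row, col) maps to flat output position (2*row + k//2) * 2W + (2*col + k%2), so the converted indices can also be computed directly:

>>> H, W = index.shape[-2:]
>>> rows = torch.arange(H).view(1, 1, -1, 1)  # block row in the H x W grid
>>> cols = torch.arange(W).view(1, 1, 1, -1)  # block column in the H x W grid
>>> (2 * rows + index // 2) * (2 * W) + (2 * cols + index % 2)
tensor([[[[ 0,  6],
          [ 9, 11]]]])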
In order to do so, we need to work on index as follows. First, convert the index tensor to a one-hot encoding to expand the number of elements, using nn.functional.one_hot:
>>> from torch.nn.functional import one_hot
>>> one_hot(index, 4)
tensor([[[[[1, 0, 0, 0],      # [[[[[a0, a1, a2, a3],
           [0, 0, 1, 0]],     #     [b0, b1, b2, b3]],
          [[0, 1, 0, 0],      #    [[c0, c1, c2, c3],
           [0, 1, 0, 0]]]]])  #     [d0, d1, d2, d3]]]]]
This has shape (1, 1, 2, 2, 4). We then need to convert it to the following layout:
# [[[[[a0, a1, b0, b1],
#     [a2, a3, b2, b3]],
#    [[c0, c1, d0, d1],
#     [c2, c3, d2, d3]]]]]
Begin by viewing the tensor as a set of 2x2 blocks:
>>> ohe = one_hot(index, 4).view(2, 2, 2, 2)
>>> ohe
tensor([[[[1, 0],     # [[[[a0, a1],
          [0, 0]],    #    [a2, a3]],
         [[0, 0],     #   [[b0, b1],
          [1, 0]]],   #    [b2, b3]]],
        [[[0, 1],     #  [[[c0, c1],
          [0, 0]],    #    [c2, c3]],
         [[0, 1],     #   [[d0, d1],
          [0, 0]]]])  #    [d2, d3]]]])
The series of steps taken so far now allows us to recover the desired layout without having to split and concatenate the elements. You can find a detailed explanation of the process in this other answer.
>>> y = ohe.transpose(-1, -2).reshape(2, -1, 2).transpose(-1, -2)
>>> y
tensor([[[1, 0, 0, 0],    # [[[a0, a1, b0, b1],
         [0, 0, 1, 0]],   #   [a2, a3, b2, b3]],
        [[0, 1, 0, 1],    #  [[c0, c1, d0, d1],
         [0, 0, 0, 0]]])  #   [c2, c3, d2, d3]]]
Then flatten the mask and recover the flat indices by multiplying with a shifted arange (the +1 shift ensures a hit at flat position 0 is not zeroed out and missed by nonzero in the next step):
>>> y.reshape(-1)
tensor([1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0])
>>> i = y.reshape(-1) * (torch.arange(4 * x.numel()) + 1)
>>> i
tensor([ 1,  0,  0,  0,  0,  0,  7,  0,  0, 10,  0, 12,  0,  0,  0,  0])
The desired index tensor is:
>>> unpool_i = i.nonzero().view_as(x)
>>> unpool_i
tensor([[[[ 0,  6],
          [ 9, 11]]]])
Finally, you can apply the unpooling layer with:
>>> unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
>>> unpool(x, unpool_i)
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.4000, 0.0000],
          [0.0000, 0.2000, 0.0000, 0.5000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]])
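Note that the intermediate shapes above are hardcoded to this 1x1x2x2 example. For arbitrary (B, C, H, W) inputs, here is a minimal sketch of the same idea using torch.Tensor.scatter_, as hinted at in your question (the helper name expand_with_index is my own, not a PyTorch API):

>>> def expand_with_index(x, index):
...     # scatter each value of x into one of 4 slots per position,
...     # then unfold the 4 slots into a 2x2 block of the output
...     B, C, H, W = x.shape
...     out = x.new_zeros(B, C, H, W, 4)
...     out.scatter_(-1, index.long().unsqueeze(-1), x.unsqueeze(-1))
...     out = out.view(B, C, H, W, 2, 2).permute(0, 1, 2, 4, 3, 5)
...     return out.reshape(B, C, 2 * H, 2 * W)
...
>>> expand_with_index(x, index)
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.4000, 0.0000],
          [0.0000, 0.2000, 0.0000, 0.5000],
          [0.0000, 0.0000, 0.0000, 0.0000]]]])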