I have a tensor x with shape [z, d, d], which represents a series of image frames, just like video data. Let pz = int(z**0.5) and let x = x.view(pz, pz, d, d). Then we get a grid of images with grid size pz*pz, where each image has shape [d, d]. Now I want to get a tensor with shape [1, 1, pz*d, pz*d], and every element MUST keep the same relative position it has in its original image.
For example:
x = [[[ 0,  1],
      [ 2,  3]],
     [[ 4,  5],
      [ 6,  7]],
     [[ 8,  9],
      [10, 11]],
     [[12, 13],
      [14, 15]]]
which represents a series of images with shape [2, 2] and z = 4. I want to get a tensor like:
tensor([[ 0,  1,  4,  5],
        [ 2,  3,  6,  7],
        [ 8,  9, 12, 13],
        [10, 11, 14, 15]])
I can use x = x.view(1, 1, 4, 4) to get one with the same shape, but it looks like this:
tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15]]]])
which I don't want.
One more thing: what if x has more dimensions, e.g. [b, c, z, d, d]? How to deal with that?
Any suggestion will be helpful.
I have a solution for the three-dimensional case. If x.shape = [z, d, d], then the code below works, but it does not work for higher-dimensional tensors. Nested loops would work, but are too heavy. My solution for the three-dimensional case:
import torch

d = 2
z = 4
b, c = 1, 1
x = torch.arange(z*d*d).view(z, d, d)
# x = tensor([[[ 0,  1],
#              [ 2,  3]],
#             [[ 4,  5],
#              [ 6,  7]],
#             [[ 8,  9],
#              [10, 11]],
#             [[12, 13],
#              [14, 15]]])
# arrange the z frames into a pz x pz grid layout
grid_side_len = int(z**0.5)
grid_x = x.view(grid_side_len, grid_side_len, d, d)
# for each row of crops, horizontally stack them together
plane = []
for i in range(grid_x.shape[0]):
    cat_crops = torch.hstack([crop for crop in grid_x[i]])
    plane.append(cat_crops)
plane = torch.vstack(plane)
print("3D crop to 2D crop plane:")
print(x)
print(plane)
print(plane.shape)
print("2D crop plane to 3D crop:")
# group all rows
split = torch.chunk(plane, plane.shape[1]//d, dim=0)
spat_flatten = torch.cat([torch.cat(torch.chunk(p, p.shape[1]//d, dim=1), dim=0) for p in split], dim=0)
crops = [t[None,:,:] for t in torch.chunk(spat_flatten, spat_flatten.shape[0]//d, dim=0)]
spat_crops = torch.cat(crops, dim=0)
print(spat_crops)
print(spat_crops.shape)
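For reference, the loop-based round trip above can be packaged into two small helper functions (the names frames_to_plane and plane_to_frames are mine, just for illustration), which makes it easy to verify that the two directions are inverses of each other:

```python
import torch

def frames_to_plane(x):
    # x: [z, d, d] with z a perfect square -> one [pz*d, pz*d] grid image
    z, d, _ = x.shape
    pz = int(z**0.5)
    grid = x.view(pz, pz, d, d)
    # stack the crops of each grid row side by side, then stack the rows
    rows = [torch.hstack(list(row)) for row in grid]
    return torch.vstack(rows)

def plane_to_frames(plane, d):
    # inverse: [pz*d, pz*d] grid image -> [z, d, d]
    pz = plane.shape[0] // d
    bands = torch.chunk(plane, pz, dim=0)   # horizontal bands of height d
    crops = [c for band in bands for c in torch.chunk(band, pz, dim=1)]
    return torch.stack(crops)

x = torch.arange(16).view(4, 2, 2)
plane = frames_to_plane(x)
print(plane)
assert torch.equal(plane_to_frames(plane, 2), x)
```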
This is an operation that can be solved with a combination of torch.transpose and torch.reshape operations. Start from the example tensor:
>>> x = torch.arange(16).view(4,2,2)
First transpose the tensor so that the dimension you want to collate on is standing "vertically"; this can be done with x.transpose(dim0=1, dim1=2). That said, I recommend working with negative dimensions instead:
>>> x.transpose(-1,-2)
tensor([[[ 0,  2],
         [ 1,  3]],
        [[ 4,  6],
         [ 5,  7]],
        [[ 8, 10],
         [ 9, 11]],
        [[12, 14],
         [13, 15]]])
Then reshape to collate the dimension:
>>> x.transpose(-1,-2).reshape(2,4,2)
tensor([[[ 0,  2],
         [ 1,  3],
         [ 4,  6],
         [ 5,  7]],
        [[ 8, 10],
         [ 9, 11],
         [12, 14],
         [13, 15]]])
Then transpose back to recover the element order from step 1:
>>> x.transpose(-1,-2).reshape(2,4,2).transpose(-1,-2)
tensor([[[ 0,  1,  4,  5],
         [ 2,  3,  6,  7]],
        [[ 8,  9, 12, 13],
         [10, 11, 14, 15]]])
Finally, reshape to the desired form:
>>> x.transpose(-1,-2).reshape(2,4,2).transpose(-1,-2).reshape(len(x),-1)
tensor([[ 0,  1,  4,  5],
        [ 2,  3,  6,  7],
        [ 8,  9, 12, 13],
        [10, 11, 14, 15]])
From there you can adapt this to your needs by changing the dimension sizes, and even extend it to tensors with more dimensions such as [b, c, z, d, d] as you described. If you work through this simple example, you will be able to solve any similar problem.
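To make the higher-dimensional case concrete, here is one way to write it out (grid_collate is my own name for the helper; this is a sketch assuming z is a perfect square): split z into a (grid_row, grid_col) pair, permute so each grid-row dimension sits next to the matching intra-image row dimension, then merge them with a reshape:

```python
import torch

def grid_collate(x):
    # x: [b, c, z, d, d] with z = pz*pz  ->  [b, c, pz*d, pz*d]
    b, c, z, d, _ = x.shape
    pz = int(z**0.5)
    x = x.view(b, c, pz, pz, d, d)          # split z into (grid_row, grid_col)
    x = x.permute(0, 1, 2, 4, 3, 5)         # -> [b, c, grid_row, d, grid_col, d]
    return x.reshape(b, c, pz * d, pz * d)  # merge (grid_row, d) and (grid_col, d)

x = torch.arange(16).view(1, 1, 4, 2, 2)
print(grid_collate(x))
```

For the [z, d, d] case you can unsqueeze to [1, 1, z, d, d] first, or drop the b and c dimensions from the function.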