The shape of my tensor is [5, 2, 18, 4096]. I want to take each slice along dimension 0 (each slice has shape [2, 18, 4096]) and join the slices end to end along dimension 1, the 18-length dimension, for all 5 slices. The final tensor should have shape [2, 90, 4096].
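For reference, the operation described above is a plain concatenation along dim 1. A minimal sketch, assuming the data is already a single tensor of shape [5, 2, 18, 4096] (here `x` is a hypothetical random stand-in, not the real buffer):

```python
import torch

# Hypothetical stand-in for the real data: 5 chunks, each of shape [2, 18, 4096]
x = torch.randn(5, 2, 18, 4096)

# Split along dim 0 into 5 tensors of shape [2, 18, 4096],
# then concatenate them end to end along dim 1 (the 18-length dim)
out = torch.cat(x.unbind(0), dim=1)
print(out.shape)  # torch.Size([2, 90, 4096])
```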
I came up with the general approach below, but is there a better way to do this? Also, is it mathematically correct?
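import torch

chunks = torch.Tensor(self.buffer)  # shape is [5, 2, 18, 4096]
chunks = chunks.permute(1, 0, 2, 3)  # move the chunk dim next to the sequence dim: [2, 5, 18, 4096]
# merge the chunk dim (5) and the sequence dim (18) into one dim of size 90
chunks = chunks.reshape(chunks.shape[0], chunks.shape[1] * chunks.shape[2], chunks.shape[-1])
# the resulting shape is [2, 90, 4096]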
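One way to check the permute-and-reshape approach is to compare it against `torch.cat` on the same input. This is a sketch assuming a random stand-in tensor `x` (hypothetical) in place of `self.buffer`; both versions only move elements around, so an exact comparison with `torch.equal` is appropriate:

```python
import torch

n_chunks, batch, seq, hidden = 5, 2, 18, 4096
x = torch.randn(n_chunks, batch, seq, hidden)  # hypothetical stand-in for the buffer

# Approach from the question: put the chunk dim next to seq, then merge the two
via_reshape = x.permute(1, 0, 2, 3).reshape(batch, n_chunks * seq, hidden)

# Reference: concatenate the chunks along dim 1
via_cat = torch.cat(x.unbind(0), dim=1)

# Both satisfy out[b, i * seq + j, :] == x[i, b, j, :]
print(torch.equal(via_reshape, via_cat))  # True
```

Because `permute` returns a non-contiguous view, `reshape` has to copy the data here; calling `.view` with the same shape would raise an error instead.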