Tags: python, machine-learning, deep-learning, pytorch, transformer-model

I have a tensor of shape [5, 2, 18, 4096]. I want to stack the slices along the 0th dimension into the 2nd dimension. How can I do that?


The tensor has shape [5, 2, 18, 4096]. I want to take each slice along the 0th dimension, which has shape [2, 18, 4096], stack it on top of the next slice of the same shape along that slice's 2nd dimension, and do this for all slices along the 0th dimension. The final tensor should have shape [2, 90, 4096].
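To make the desired operation concrete, here is a minimal sketch on a smaller stand-in tensor (random data of shape [5, 3, 4, 6] instead of [5, 2, 18, 4096]); each slice along dim 0 is concatenated along what was the 2nd dimension:

    import torch

    x = torch.randn(5, 3, 4, 6)        # stand-in for the real [5, 2, 18, 4096] tensor
    slices = x.unbind(dim=0)           # five tensors, each of shape [3, 4, 6]
    out = torch.cat(slices, dim=1)     # concatenate along the former 2nd dimension
    print(out.shape)                   # torch.Size([3, 20, 6]), i.e. (3, 5*4, 6)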


Solution

  • I did get to a general approach that solves this, but is there a better way to do it? Also, is it mathematically correct? (An equivalence check is sketched after the snippet.)

    import torch

    chunks = torch.Tensor(self.buffer)   # shape is [5, 2, 18, 4096]
    chunks = chunks.permute(1, 0, 2, 3)  # shape is [2, 5, 18, 4096]
    # merge the chunk dimension (5) into the sequence dimension (18) -> 90
    chunks = chunks.reshape(chunks.shape[0], chunks.shape[1] * chunks.shape[2], chunks.shape[-1])
    # the resulting shape is [2, 90, 4096]
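As a sanity check that the permute + reshape pair produces the intended ordering, here is a minimal sketch comparing it against an explicit unbind/cat over the 0th dimension (using a random stand-in tensor, since `self.buffer` is not shown):

    import torch

    x = torch.randn(5, 2, 18, 4096)    # random stand-in for the buffer tensor

    # permute + reshape, as in the snippet above
    a = x.permute(1, 0, 2, 3).reshape(x.shape[1], -1, x.shape[-1])   # [2, 90, 4096]

    # explicit version: concatenate the five [2, 18, 4096] slices along dim 1
    b = torch.cat(x.unbind(dim=0), dim=1)                            # [2, 90, 4096]

    print(torch.equal(a, b))           # True -> the ordering matches

So the approach is correct: for each batch entry, the 18 rows of chunk 0 come first, then the 18 rows of chunk 1, and so on. An equivalent, slightly shorter spelling is `x.transpose(0, 1).reshape(2, -1, 4096)`.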