
How to repeat tensor in a specific new dimension in PyTorch


If I have a tensor A with shape [M, N], I want to repeat it K times so that the result B has shape [M, K, N] and each slice B[:, k, :] holds the same data as A. What is the best practice for doing this without a for loop? K might also be in another dimension.

torch.repeat_interleave() and tensor.repeat() do not seem to work, or I am using them the wrong way.
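A minimal sketch of why neither call works directly, assuming small illustrative sizes M=2, N=3, K=4: both functions only repeat along dimensions the tensor already has, so neither one creates the new K axis on its own.

```python
import torch

M, N, K = 2, 3, 4
A = torch.arange(M * N).reshape(M, N)  # shape [M, N]

# repeat() with one factor per existing dim tiles those dims:
# the K copies of A are stacked along dim 0, giving [K*M, N], not [M, K, N].
tiled = A.repeat(K, 1)
print(tiled.shape)  # torch.Size([8, 3])

# repeat_interleave() repeats each row K times back to back,
# again along an existing dim, giving [M*K, N].
interleaved = A.repeat_interleave(K, dim=0)
print(interleaved.shape)  # torch.Size([8, 3])
```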


Solution

  • tensor.repeat should suit your needs, but you need to insert a singleton (size-1) dimension first. You could use either tensor.unsqueeze or tensor.reshape for this; since unsqueeze is specifically defined to insert a size-1 dimension, we will use that.

    B = A.unsqueeze(1).repeat(1, K, 1)
    

    Code description: A.unsqueeze(1) turns A from shape [M, N] into [M, 1, N], and .repeat(1, K, 1) repeats the tensor K times along the second dimension (dim 1), giving [M, K, N].
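Since K might sit in another dimension, the same unsqueeze-then-repeat idea can be generalized. The sketch below uses a hypothetical helper, repeat_along_new_dim (not a PyTorch API), that inserts the size-1 axis at an arbitrary position `dim` and repeats only that axis. It also shows tensor.expand as an alternative that returns a view instead of copying the data K times.

```python
import torch

def repeat_along_new_dim(A: torch.Tensor, K: int, dim: int) -> torch.Tensor:
    """Hypothetical helper: copy A K times along a new axis inserted at `dim`."""
    # Insert the size-1 axis, e.g. [M, N] -> [M, 1, N] for dim=1.
    unsqueezed = A.unsqueeze(dim)
    # Repeat factor K on the new axis, 1 (unchanged) on every other axis.
    # Negative dims work too: reps[dim] indexes the same axis unsqueeze created.
    reps = [1] * unsqueezed.dim()
    reps[dim] = K
    return unsqueezed.repeat(*reps)

M, N, K = 2, 3, 4
A = torch.arange(M * N, dtype=torch.float32).reshape(M, N)

B = repeat_along_new_dim(A, K, dim=1)
print(B.shape)  # torch.Size([2, 4, 3])

# K in another position: first or last.
print(repeat_along_new_dim(A, K, dim=0).shape)   # torch.Size([4, 2, 3])
print(repeat_along_new_dim(A, K, dim=-1).shape)  # torch.Size([2, 3, 4])

# No-copy alternative: expand() returns a view of A; -1 keeps a dim's size.
B_view = A.unsqueeze(1).expand(-1, K, -1)
print(B_view.shape)  # torch.Size([2, 4, 3])
```

The trade-off: expand avoids allocating K copies, but all K slices share A's memory, so call .clone() on the result before modifying it in place.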