
Reshaping torch tensor


I have a torch tensor of shape (1,3,8). I want to grow the first dimension to n, so that the final tensor has shape (n,3,8), with the new entries filled with zeros. Here is what I have so far:

import torch

n = 5
a = torch.randn(1,3,8) # Random (1,3,8) tensor
b = torch.cat((a,torch.zeros_like(a)))
for i in range(n-2):
    b = torch.cat((b,torch.zeros_like(a)))
print(b.shape) # torch.Size([5, 3, 8])

This works, but is there a better and more elegant solution?


Solution

  • You can avoid the loop by creating all n-1 zero entries along the first dimension in one go and concatenating once:

    torch.cat((a, torch.zeros(n - 1, a.shape[1], a.shape[2])))
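A minimal sketch of three loop-free variants, assuming PyTorch is installed and that the zeros go at the end of dimension 0. The helper names (`pad_with_cat`, `pad_with_f_pad`, `pad_with_prealloc`) are hypothetical; `Tensor.new_zeros` and `torch.nn.functional.pad` are standard PyTorch calls that the answer above does not use.

```python
import torch
import torch.nn.functional as F


def pad_with_cat(a: torch.Tensor, n: int) -> torch.Tensor:
    # Same idea as the answer above, but new_zeros copies a's dtype and device,
    # so this also works when a lives on the GPU or is not float32.
    zeros = a.new_zeros((n - 1, *a.shape[1:]))
    return torch.cat((a, zeros), dim=0)


def pad_with_f_pad(a: torch.Tensor, n: int) -> torch.Tensor:
    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # every dimension except dim 0 gets (0, 0), and the final pair (0, n - 1)
    # appends n - 1 zero entries at the end of dim 0.
    pad = (0, 0) * (a.dim() - 1) + (0, n - 1)
    return F.pad(a, pad, mode="constant", value=0)


def pad_with_prealloc(a: torch.Tensor, n: int) -> torch.Tensor:
    # Allocate the full (n, ...) zero tensor once, then copy a into the front.
    out = a.new_zeros((n, *a.shape[1:]))
    out[: a.shape[0]] = a
    return out


n = 5
a = torch.randn(1, 3, 8)
b1 = pad_with_cat(a, n)
b2 = pad_with_f_pad(a, n)
b3 = pad_with_prealloc(a, n)
print(b1.shape)  # torch.Size([5, 3, 8])
```

All three produce the same tensor. Plain `torch.zeros(...)` defaults to the CPU and the default dtype, so if `a` is on a GPU the original one-liner would need `dtype=a.dtype, device=a.device`; `new_zeros` and `F.pad` inherit both from `a` automatically.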