Suppose I have a tensor like [A,A,A,A...,A]. How can I quickly obtain [[A],[A],[A],...,[A]] as a tensor in torch?
You can use torch.chunk as the inverse of cat, but it looks like you want unsqueeze(1):
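import torch

A = torch.randn(2, 3)
A_rep = (A, A, A, A, A, A, A, A)
catted = torch.cat(A_rep)  # shape (16, 3): eight copies of A joined along dim 0
# uncatted = torch.chunk(catted, len(A_rep))  # inverse of cat: tuple of eight (2, 3) tensors
catted.unsqueeze(1)  # shape (16, 1, 3): every row wrapped in its own dimension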
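
Note that unsqueeze(1) wraps each row of the concatenated tensor, not each whole copy of A. If A has more than one row and you want every full copy in its own slot, one possible approach is to chunk, stack, and then unsqueeze (or reshape directly). The sketch below compares the two readings; the names n, rows, blocks and blocks_view are made up for illustration, and it assumes the copies were concatenated along dim 0.

```python
import torch

A = torch.randn(2, 3)
n = 8  # number of copies (illustrative)
catted = torch.cat([A] * n)  # shape (16, 3)

# Row-wise reading: each row of catted gets a singleton dimension.
rows = catted.unsqueeze(1)  # shape (16, 1, 3)

# Block-wise reading: recover the n copies of A, stack them, then wrap each one.
chunks = torch.chunk(catted, n)  # tuple of n tensors, each of shape (2, 3)
blocks = torch.stack(chunks).unsqueeze(1)  # shape (8, 1, 2, 3)

# The same result in one step, assuming catted is laid out copy after copy.
blocks_view = catted.reshape(n, 1, *A.shape)

print(rows.shape, blocks.shape, blocks_view.shape)
```

When A is a 1-D tensor or a single row, the two readings coincide and unsqueeze(1) alone is enough.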