
Reverse operation of torch.cat


Suppose I have a tensor like [A, A, A, A, ..., A].

How can I quickly obtain [[A], [A], [A], ..., [A]] as a tensor in PyTorch?
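Read literally, with A a single number, this asks to turn a 1-D tensor of shape (n,) into a 2-D tensor of shape (n, 1). Here is a minimal sketch, assuming the hypothetical values A = 7 and n = 5. It uses `unsqueeze(1)` and the equivalent `None` indexing:

```python
import torch

# Hypothetical values for illustration: A = 7, repeated n = 5 times
n = 5
flat = torch.tensor([7] * n)   # [A, A, A, A, A], shape (5,)

# Insert a size-1 dimension at position 1 so each element gets its own row
nested = flat.unsqueeze(1)     # [[A], [A], [A], [A], [A]], shape (5, 1)

# Indexing with None inserts the same new axis
same = flat[:, None]

print(flat.shape, nested.shape)  # torch.Size([5]) torch.Size([5, 1])
print(nested.tolist())           # [[7], [7], [7], [7], [7]]
```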


Solution

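  • torch.chunk is the inverse of torch.cat: it splits a tensor back into a given number of pieces along a dimension. For the shape you describe, though, it looks like you want unsqueeze(1), which inserts a new dimension of size 1 at position 1: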

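    import torch

    A = torch.randn(2, 3)
    A_rep = (A, A, A, A, A, A, A, A)
    catted = torch.cat(A_rep)  # shape (16, 3): eight copies of A stacked along dim 0

    # torch.chunk undoes torch.cat: a tuple of 8 tensors, each of shape (2, 3)
    uncatted = torch.chunk(catted, len(A_rep))

    # unsqueeze(1) inserts a size-1 dim at position 1, wrapping each row: shape (16, 1, 3)
    wrapped = catted.unsqueeze(1)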
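When A is itself a (2, 3) tensor, as in the answer's code, `unsqueeze(1)` wraps each *row* of the concatenated tensor, not each whole copy of A. The sketch below uses the same random (2, 3) A. It checks that `torch.chunk` recovers every copy. It also shows one way to wrap whole copies, a `view` into shape (copies, 1, *A.shape). That reshape is an addition of this edit, not part of the original answer:

```python
import torch

A = torch.randn(2, 3)
A_rep = (A, A, A, A, A, A, A, A)
catted = torch.cat(A_rep)  # shape (16, 3)

# torch.chunk splits catted back into len(A_rep) pieces along dim 0
pieces = torch.chunk(catted, len(A_rep))
print(all(torch.equal(p, A) for p in pieces))  # True

# unsqueeze(1) wraps each row of catted
per_row = catted.unsqueeze(1)

# To wrap each whole copy of A instead, split dim 0 into (copies, rows of A)
# and insert a size-1 dim; catted is contiguous, so view() works here
per_block = catted.view(len(A_rep), 1, *A.shape)

# Equivalent route: stack the chunks, then add the size-1 dim
per_block_alt = torch.stack(pieces).unsqueeze(1)

print(per_row.shape)    # torch.Size([16, 1, 3])
print(per_block.shape)  # torch.Size([8, 1, 2, 3])
```

The `view` version avoids allocating new memory, while `torch.stack` copies the chunks into a fresh tensor. Both give the same values.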