python pytorch tensor

Turning list of 2D tensors with different length to one 3D tensor


I have a list of 3 tensors with the shape: (8, 2), (8, 4), (8, 6)

And I want to turn this list into this shape: (8, 3, x)

How do I do this? I know I need to use some combination of torch.cat, torch.stack and torch.transpose, but I can't figure it out.

Thanks in advance!


Solution

  • As you said, you need to use torch.cat, but also torch.reshape. Assume the following:

    import torch

    a = torch.rand(8, 2)
    b = torch.rand(8, 4)
    c = torch.rand(8, 6)
    

    And assume that it is indeed possible to reshape the tensors to an (8, 3, -1) shape, where -1 stands for "as long as it needs to be" (this requires the total number of columns, here 2 + 4 + 6 = 12, to be divisible by 3), then:

    d = torch.cat((a,b,c), dim=1)
    e = torch.reshape(d, (8,3,-1))
    

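    I'll explain. Because a, b and c differ only in dimension 1 (the second dimension, counting from 0; dimension 0 is 8 for all of them), the concatenation has to be along dim=1, as seen in variable d, which has shape (8, 12). Then you can reshape the tensor as seen in e, where the -1 stands for "as long as it needs to be"; here it is inferred as 12 / 3 = 4, so e has shape (8, 3, 4).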
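
To check the shapes, here is a minimal runnable sketch of the same steps. It also shows one thing worth knowing: the reshape splits the concatenated columns into three equal chunks, so the slices of e do not line up with a, b and c. The last part is an alternative that is not from the answer above: if you want one slice per original tensor, you could zero-pad each tensor to the widest one and stack them (the names max_w, padded and stacked are just for illustration).

```python
import torch
import torch.nn.functional as F

a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)

# Concatenate along dim=1 (columns): 2 + 4 + 6 = 12 columns in total.
d = torch.cat((a, b, c), dim=1)
print(d.shape)  # torch.Size([8, 12])

# Split the 12 columns into 3 groups; -1 is inferred as 12 / 3 = 4.
e = torch.reshape(d, (8, 3, -1))
print(e.shape)  # torch.Size([8, 3, 4])

# The groups are consecutive column chunks of d, not the original tensors:
# e[:, 0, :] holds all of a followed by the first two columns of b.
assert torch.equal(e[:, 0, :], torch.cat((a, b[:, :2]), dim=1))

# Alternative (an assumption about what may be wanted, not the answer's method):
# zero-pad every tensor on the right to the widest width, then stack on a new dim=1.
max_w = max(t.shape[1] for t in (a, b, c))
padded = [F.pad(t, (0, max_w - t.shape[1])) for t in (a, b, c)]
stacked = torch.stack(padded, dim=1)
print(stacked.shape)  # torch.Size([8, 3, 6])
```

With the padded version, stacked[:, 0, :2] is a, stacked[:, 1, :4] is b and stacked[:, 2, :] is c, at the cost of extra zero entries.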