I have a list of 3 tensors with the shapes (8, 2), (8, 4), and (8, 6),
and I want to turn this list into a single tensor of shape (8, 3, x).
How do I do this? I know I need to use some combination of torch.cat, torch.stack and torch.transpose, but I can't figure it out.
Thanks in advance!
As you said, you need to use torch.cat, but also torch.reshape. Assume the following:
import torch

a = torch.rand(8,2)
b = torch.rand(8,4)
c = torch.rand(8,6)
And assume that it is indeed possible to reshape the concatenated tensor to a (8,3,-1) shape, where -1 stands for "as long as it needs to be" (here the concatenated width is 2 + 4 + 6 = 12, which is divisible by 3), then:
d = torch.cat((a,b,c), dim=1)     # shape (8, 12)
e = torch.reshape(d, (8,3,-1))    # shape (8, 3, 4)
I'll explain. Because a, b and c differ in size only along dimension 1 (the second axis), the concatenation has to be along that dimension (dim=1), as seen in variable d. Then you can reshape the tensor as seen in e, where the -1 stands for "as long as it needs to be", which works out to 4 here.
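To make the shapes concrete, here is a minimal sketch of the steps above with the resulting sizes checked. One thing worth noting is that the reshape simply cuts the 12 concatenated columns into 3 equal runs of 4, so the groups do not line up with the original a, b and c. The last part is not from the answer above: it is only an assumption about what might be wanted if each of the 3 slots should hold one original tensor, and it zero-pads every tensor to the widest width (6) before using torch.stack. The names padded and s are just illustrative.

```python
import torch
import torch.nn.functional as F

a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)

# Concatenate along dim=1: widths 2 + 4 + 6 = 12, so d has shape (8, 12).
d = torch.cat((a, b, c), dim=1)

# 12 columns split evenly into 3 groups, so -1 is inferred as 4.
e = torch.reshape(d, (8, 3, -1))
print(d.shape)  # torch.Size([8, 12])
print(e.shape)  # torch.Size([8, 3, 4])

# Each group is a contiguous run of 4 columns of d, not one original tensor:
# e[:, 0] holds a (2 columns) followed by the first 2 columns of b.
print(torch.equal(e[:, 0, :2], a))         # True
print(torch.equal(e[:, 0, 2:], b[:, :2]))  # True

# Alternative (an assumption, not part of the answer above): zero-pad each
# tensor on the right of its last dimension up to width 6, then stack along
# a new dim=1 so that slot i holds tensor i.
padded = [F.pad(t, (0, 6 - t.shape[1])) for t in (a, b, c)]
s = torch.stack(padded, dim=1)
print(s.shape)  # torch.Size([8, 3, 6])
```

Which variant is appropriate depends on what x is supposed to mean: the cat + reshape version keeps every value and adds nothing, while the padded version keeps the tensors separate at the cost of extra zeros.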