
Horizontal stacking in PyTorch


I am trying to implement transformers and am stuck at one point.

Say I have an input sequence of shape [2,20], where 2 is the number of samples and 20 is the number of words in the sequence (the sequence length).

So, I create an array like [0,1,2, ... 19] of shape [1,20]. Now I want to stack it so that the final shape is [2,20], in line with the input sequence, like below:

[[0,1,2, ... 19],
[0,1,2, ... 19]]

Is there a torch function for doing this? I could loop and build the arrays manually, but I wanted to avoid that.


Solution

  • If the tensors you want to stack are of shape [1,20], you can use torch.cat()

    t1 = torch.zeros([1,5]) # tensor([[0., 0., 0., 0., 0.]])
    t2 = torch.ones([1,5]) # tensor([[1., 1., 1., 1., 1.]])
    
    torch.cat([t1, t2]) # tensor([[0., 0., 0., 0., 0.],
                                  [1., 1., 1., 1., 1.]])
    
    

  • If the tensors are 1-D, you can simply use torch.stack()

    t1 = torch.zeros([5]) # tensor([0., 0., 0., 0., 0.])
    t2 = torch.ones([5]) # tensor([1., 1., 1., 1., 1.])
    
    torch.stack([t1, t2]) # tensor([[0., 0., 0., 0., 0.],
                                    [1., 1., 1., 1., 1.]])
    
    

  • Now, for a shorter method for your case, you can do:

    torch.arange(0,20).repeat(2,1) # tensor([[0,1,2, ... 19],
                                             [0,1,2, ... 19]])