
Extract a sub-tensor in PyTorch


For this tensor in PyTorch:

tensor([[ 0.7646,  0.5573,  0.4000,  0.2188,  0.7646,  0.5052,  0.2042,  0.0896,
          0.7667,  0.5938,  0.3167,  0.0917],
        [ 0.4271,  0.1354,  0.5000,  0.1292,  0.4260,  0.1354,  0.4646,  0.0917,
         -1.0000, -1.0000, -1.0000, -1.0000],
        [ 0.7208,  0.5656,  0.3000,  0.1688,  0.7177,  0.5271,  0.1521,  0.0667,
          0.7198,  0.5948,  0.2438,  0.0729],
        [ 0.6292,  0.8250,  0.4000,  0.2292,  0.6271,  0.7698,  0.2083,  0.0812,
          0.6281,  0.8604,  0.3604,  0.0917]], device='cuda:0')

How can I extract these values into a new tensor?

0.7646,  0.5573,  0.4000,  0.2188
0.4271,  0.1354,  0.5000,  0.1292

In other words, how do I get the first 4 values of the first two rows into a new tensor?


Solution

  • The question was already answered by @zihaozhihao in the comments, but in case you are wondering where that answer comes from, it helps to lay out your tensor like this:

    import torch

    x = torch.tensor([
        [0.7646, 0.5573, 0.4000, 0.2188, 0.7646, 0.5052, 0.2042, 0.0896, 0.7667, 0.5938, 0.3167, 0.0917],
        [0.4271, 0.1354, 0.5000, 0.1292, 0.4260, 0.1354, 0.4646, 0.0917, -1.0000, -1.0000, -1.0000, -1.0000],
        [0.7208, 0.5656, 0.3000, 0.1688, 0.7177, 0.5271, 0.1521, 0.0667, 0.7198, 0.5948, 0.2438, 0.0729],
        [0.6292, 0.8250, 0.4000, 0.2292, 0.6271, 0.7698, 0.2083, 0.0812, 0.6281, 0.8604, 0.3604, 0.0917],
    ])
    

    Now it is clearer that x has shape (4, 12). You can think of it like a spreadsheet with 4 rows and 12 columns. You want the first 4 columns of the first two rows, so the solution is:

    x[:2, :4]    # rows 0 and 1 (stop index 2 is excluded), columns 0 to 3 (stop index 4 is excluded)
    x[0:2, 0:4]  # equivalent, with the start indices written out explicitly
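Since the question asks for a *new* tensor, it may help to see the difference between a view and a copy. Below is a minimal sketch, assuming a CPU copy of the question's tensor (the original lives on `cuda:0`, where indexing behaves the same way). It shows that `x[:2, :4]` returns a view sharing memory with `x`, that `.clone()` gives an independent tensor, and that `Tensor.narrow` (an alternative not mentioned in the original answer) selects the same block.

```python
import torch

# The (4, 12) tensor from the question, built on the CPU for simplicity.
x = torch.tensor([
    [0.7646, 0.5573, 0.4000, 0.2188, 0.7646, 0.5052, 0.2042, 0.0896, 0.7667, 0.5938, 0.3167, 0.0917],
    [0.4271, 0.1354, 0.5000, 0.1292, 0.4260, 0.1354, 0.4646, 0.0917, -1.0000, -1.0000, -1.0000, -1.0000],
    [0.7208, 0.5656, 0.3000, 0.1688, 0.7177, 0.5271, 0.1521, 0.0667, 0.7198, 0.5948, 0.2438, 0.0729],
    [0.6292, 0.8250, 0.4000, 0.2292, 0.6271, 0.7698, 0.2083, 0.0812, 0.6281, 0.8604, 0.3604, 0.0917],
])
print(x.shape)  # torch.Size([4, 12])

# Basic slicing: rows 0-1, columns 0-3. The result is a view that shares memory with x.
view = x[:2, :4]

# .clone() copies the selected data, producing an independent new tensor.
sub = x[:2, :4].clone()

# Tensor.narrow(dim, start, length) selects the same block one dimension at a time
# (it also returns a view, not a copy).
narrowed = x.narrow(0, 0, 2).narrow(1, 0, 4)
print(torch.equal(sub, narrowed))  # True

# Writing through the view modifies x; the clone keeps the original values.
view[0, 0] = 100.0
print(x[0, 0].item())             # 100.0
print(sub[0, 0].item() == 100.0)  # False
```

If you only read the sub-tensor, the view is enough and avoids a copy; clone it when you plan to modify either tensor independently.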