Slice nested pytorch tensor


Is there a nice way, perhaps with slicing, to remove the middle 10 elements from each row of a tensor? So this

tensor([[ 0.1585, -0.1414, -0.0166,  0.1008, -0.0693,  0.1748,  0.1587, -0.0262,
         -0.1192,  0.0404,  0.0661, -0.1363, -0.1636,  0.0900,  0.1460, -0.1348,
          0.1293,  0.0736,  0.1186,  0.0339, -0.1225,  0.1066,  0.1314, -0.0322,
         -0.1445,  0.0966, -0.1792,  0.0682,  0.0310,  0.1446],
        [ 0.0143, -0.0800,  0.0851,  0.1199, -0.1364, -0.0677, -0.0623, -0.0921,
          0.0284, -0.0700, -0.1295, -0.1681, -0.1670,  0.0758, -0.1469,  0.0280,
          0.0460, -0.1545, -0.1377,  0.1204, -0.0134, -0.0046,  0.1248, -0.1202,
         -0.1177, -0.0598,  0.1648,  0.0955,  0.1262, -0.0785]]) 

would turn into this?

tensor([[ 0.1585, -0.1414, -0.0166,  0.1008, -0.0693,  0.1748,  0.1587, -0.0262,
         -0.1192,  0.0404, -0.1225,  0.1066,  0.1314, -0.0322,
         -0.1445,  0.0966, -0.1792,  0.0682,  0.0310,  0.1446],
        [ 0.0143, -0.0800,  0.0851,  0.1199, -0.1364, -0.0677, -0.0623, -0.0921,
          0.0284, -0.0700, -0.0134, -0.0046,  0.1248, -0.1202,
         -0.1177, -0.0598,  0.1648,  0.0955,  0.1262, -0.0785]]) 

Solution

  • There are several ways to do this. One is to build an index tensor containing the columns you want to keep (0 to 9 and 20 to the end) and index the tensor with it:

    import torch

    # Keep columns [0, 10) and [20, end); 10 and 20 are hardcoded but could be variables
    idx = torch.cat((torch.arange(0, 10), torch.arange(20, qwe.shape[1])))
    qwe[:, idx]
    

    Output:

    tensor([[ 0.1585, -0.1414, -0.0166,  0.1008, -0.0693,  0.1748,  0.1587, -0.0262,
             -0.1192,  0.0404, -0.1225,  0.1066,  0.1314, -0.0322, -0.1445,  0.0966,
             -0.1792,  0.0682,  0.0310,  0.1446],
            [ 0.0143, -0.0800,  0.0851,  0.1199, -0.1364, -0.0677, -0.0623, -0.0921,
              0.0284, -0.0700, -0.0134, -0.0046,  0.1248, -0.1202, -0.1177, -0.0598,
              0.1648,  0.0955,  0.1262, -0.0785]])
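If you would rather use plain slicing, you can cut out the two pieces you want to keep and concatenate them, e.g. `torch.cat((qwe[:, :10], qwe[:, 20:]), dim=1)`. The sketch below generalizes that to any dimension and block size. `drop_middle` is a hypothetical helper name, and it assumes "middle" means the block of `k` entries starting at `(n - k) // 2`, which for n = 30 and k = 10 is columns 10 to 19, matching the example above.

```python
import torch


def drop_middle(t: torch.Tensor, k: int, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: return t with the middle k entries removed along dim."""
    n = t.shape[dim]
    if not 0 <= k <= n:
        raise ValueError(f"k must be in [0, {n}], got {k}")
    start = (n - k) // 2  # first index of the block to drop
    end = start + k       # one past the last index of the block to drop
    # Move the target dim to the end so plain [..., a:b] slicing applies to it
    x = t.movedim(dim, -1)
    kept = torch.cat((x[..., :start], x[..., end:]), dim=-1)
    # Move the dimension back to its original position
    return kept.movedim(-1, dim)


# Stand-in for the 2 x 30 tensor from the question
qwe = torch.arange(60, dtype=torch.float32).reshape(2, 30)
out = drop_middle(qwe, 10)
print(out.shape)  # torch.Size([2, 20])
```

The slices themselves are views, but `torch.cat` (like the integer-index version) allocates a new tensor, so both approaches copy the kept data once.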