
Delete rows by value from a torch tensor (a drop method in PyTorch)


Let's say I have a PyTorch tensor:

import torch
x = torch.tensor([
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12]
    ])

And I want to delete the row with values [5,6,7,8]. I have seen this answer (which solves the problem by indexing), this one (which solves the problem by masking), and this one and this one (which delete rows when the index is known).

In my case, I know the values of the row I want to delete, but not its index, and a row should only be deleted if its values match in every column of the tensor.
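To make "match in every column" concrete: comparing x with the target row broadcasts the target across every row of x, and reducing with torch.all over dim=1 collapses each row to a single boolean. A minimal sketch; the name target is introduced here for illustration only:

```python
import torch

x = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
])
target = torch.tensor([5, 6, 7, 8])  # the row values to look for (illustrative name)

# Element-wise comparison: target (shape [4]) is broadcast against each row of x (shape [3, 4])
elementwise = x == target
# Reduce across columns (dim=1): True only for rows where every column matches
row_matches = torch.all(elementwise, dim=1)
# Reducing over dim=0 instead would ask whether each *column* matches, which is not the goal here
print(row_matches.tolist())  # [False, True, False]
```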

I could try doing the masking from this question and then indexing the rows as shown here, something like this:

ind = torch.nonzero(torch.all(x == torch.tensor([5, 6, 7, 8]), dim=1)).item()  # index of the matching row
x = torch.cat((x[:ind], x[ind + 1:]))  # rebuild x without that row

That works, but I'd like a cleaner solution than splitting the tensor and concatenating it again, something similar to the drop() method of pandas DataFrames.


Solution

  • You can combine torch.all with the ~ (logical NOT) operator to build a boolean mask that keeps every row except the one(s) matching the given values. It is a single line of code and does not split the tensor.

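    import torch
    
    x = torch.tensor([[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12]])
    
    # torch.all(..., dim=1) is True for each row equal to [5, 6, 7, 8]; ~ inverts it to keep the others
    x = x[~torch.all(x == torch.tensor([5, 6, 7, 8]), dim=1)]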
    

    The resulting tensor is:

    tensor([[ 1,  2,  3,  4],
            [ 9, 10, 11, 12]])
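If you want something closer to pandas' drop(), for example removing several rows at once, the same mask idea generalizes with broadcasting. The sketch below uses a hypothetical helper named drop_rows (it is not a PyTorch API) and assumes a 2-D tensor plus target rows with the same width and dtype:

```python
import torch


def drop_rows(x: torch.Tensor, rows: torch.Tensor) -> torch.Tensor:
    """Return a copy of the 2-D tensor `x` without any row equal to a row of `rows`.

    Hypothetical helper, loosely modeled on pandas' drop(); not part of PyTorch.
    `rows` may be a single row of shape [C] or a batch of rows of shape [K, C].
    """
    if rows.dim() == 1:
        rows = rows.unsqueeze(0)  # treat a single row as a batch of one: [1, C]
    # Compare every row of x ([N, 1, C]) with every target row ([1, K, C]) -> [N, K, C]
    matches = (x.unsqueeze(1) == rows.unsqueeze(0)).all(dim=2)  # [N, K]: full-row matches
    keep = ~matches.any(dim=1)  # [N]: True for rows that match none of the targets
    return x[keep]


x = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
])
print(drop_rows(x, torch.tensor([5, 6, 7, 8])).tolist())
# [[1, 2, 3, 4], [9, 10, 11, 12]]
print(drop_rows(x, torch.tensor([[1, 2, 3, 4], [9, 10, 11, 12]])).tolist())
# [[5, 6, 7, 8]]
```

Like the one-liner above, boolean indexing returns a new tensor rather than modifying x in place. The broadcast comparison builds an N × K × C boolean intermediate, so with a very large number of target rows it may be worth comparing in chunks.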