Tags: python, pytorch, torch, tensor

Deleting Rows in Torch Tensor


I have the following torch tensor:

a = torch.tensor(
[[0.2215, 0.5859, 0.4782, 0.7411],
[0.3078, 0.3854, 0.3981, 0.5200],
[0.1363, 0.4060, 0.2030, 0.4940],
[0.1640, 0.6025, 0.2267, 0.7036],
[0.2445, 0.3032, 0.3300, 0.4253]], dtype=torch.float64)

If the first value of a row is less than 0.2, the whole row should be deleted. The expected output is:

tensor([[0.2215, 0.5859, 0.4782, 0.7411],
        [0.3078, 0.3854, 0.3981, 0.5200],
        [0.2445, 0.3032, 0.3300, 0.4253]], dtype=torch.float64)

I tried looping through the tensor and appending the valid rows to a new, empty tensor, but I could not get it to work. Is there an efficient way to do this?
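For reference, the loop approach described above can be made to work by collecting the rows to keep in a plain Python list and stacking them once at the end, instead of growing a tensor on every iteration. This is a minimal sketch assuming the tensor `a` from the question; `kept_rows` and `result` are illustrative names. Because it runs a Python-level loop, it will be slower than boolean-mask indexing on large tensors.

```python
import torch

a = torch.tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
     [0.3078, 0.3854, 0.3981, 0.5200],
     [0.1363, 0.4060, 0.2030, 0.4940],
     [0.1640, 0.6025, 0.2267, 0.7036],
     [0.2445, 0.3032, 0.3300, 0.4253]],
    dtype=torch.float64,
)

# Gather the rows to keep in a Python list; appending to a tensor instead
# would need torch.cat on every iteration, copying the data each time.
kept_rows = []
for row in a:
    if row[0] >= 0.2:  # rows whose first value is below 0.2 are skipped
        kept_rows.append(row)

# Stack the surviving 1-D rows back into a 2-D tensor in one call.
# torch.stack raises on an empty list, so handle the "no rows kept" case.
if kept_rows:
    result = torch.stack(kept_rows)
else:
    result = a.new_empty((0, a.shape[1]))

print(result)
```

The result keeps the original `float64` dtype, since `torch.stack` preserves the dtype of its inputs.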


Solution

  • Code

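    import torch

    a = torch.tensor(
        [[0.2215, 0.5859, 0.4782, 0.7411],
         [0.3078, 0.3854, 0.3981, 0.5200],
         [0.1363, 0.4060, 0.2030, 0.4940],
         [0.1640, 0.6025, 0.2267, 0.7036],
         [0.2445, 0.3032, 0.3300, 0.4253]],
        dtype=torch.float64)

    # a[:, 0] is the first column; comparing it gives one boolean per row.
    # Rows whose first value is below 0.2 get False and are dropped.
    y = a[a[:, 0] >= 0.2]
    print(y)

    Output

    tensor([[0.2215, 0.5859, 0.4782, 0.7411],
            [0.3078, 0.3854, 0.3981, 0.5200],
            [0.2445, 0.3032, 0.3300, 0.4253]], dtype=torch.float64)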
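If the row indices themselves are useful (for example, to drop the same rows from a second tensor of labels), the boolean mask can be turned into integer indices first. This is a minimal sketch; `drop_rows_below` is a hypothetical helper name, not a PyTorch API. It uses `>=` so that a row whose first value is exactly 0.2 is kept, matching the "less than 0.2 is deleted" rule.

```python
import torch

def drop_rows_below(t: torch.Tensor, threshold: float, col: int = 0) -> torch.Tensor:
    """Return a new tensor without the rows whose value in column `col` is below `threshold`."""
    # Boolean mask: True for rows to keep.
    keep = t[:, col] >= threshold
    # Turn the mask into a 1-D tensor of row indices (int64).
    idx = torch.nonzero(keep, as_tuple=True)[0]
    # Gather those rows along dimension 0.
    return t.index_select(0, idx)

a = torch.tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
     [0.3078, 0.3854, 0.3981, 0.5200],
     [0.1363, 0.4060, 0.2030, 0.4940],
     [0.1640, 0.6025, 0.2267, 0.7036],
     [0.2445, 0.3032, 0.3300, 0.4253]],
    dtype=torch.float64,
)

y = drop_rows_below(a, 0.2)
print(y)
# Same rows as plain boolean-mask indexing.
print(torch.equal(y, a[a[:, 0] >= 0.2]))
```

Both forms return a new tensor rather than modifying `a`; PyTorch tensors have a fixed size, so "deleting" rows always means building a smaller tensor.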