
PyTorch min by columns, ignoring NaN


I have a 2D torch tensor that contains NaN values. I would like to get the minimum of each column while ignoring the NaN cells.

import torch

data = torch.tensor([[ 0.,  1., float('nan'),  3.],[ 4.,  5.,  6.,  7.], [ 8.,  9., 10., 11.]])
torch.min(data,0)

# What I would like to get is
# tensor([0., 1., 6., 3.])

Any suggestions? Thanks
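For reference, here is a quick check of what the plain call actually returns. This is a sketch that assumes a recent PyTorch release, where min reductions propagate NaN: with a `dim` argument, `torch.min` returns a named tuple of `values` and `indices`, and the column containing NaN gets NaN as its minimum.

```python
import torch

data = torch.tensor([[0., 1., float('nan'), 3.],
                     [4., 5., 6., 7.],
                     [8., 9., 10., 11.]])

# With a dim argument, torch.min returns a named tuple (values, indices).
result = torch.min(data, 0)
print(result.values)   # the NaN in column 2 propagates into that column's minimum
print(result.indices)
```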


Solution

  • You can do this by replacing all NaN values in the tensor with a very large value and then running torch.min:

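     # Replace NaN with a large sentinel (10e14 == 1e15) so it never wins the min
     data = torch.nan_to_num(data, nan=10e14)
     # torch.min with a dim returns (values, indices); keep only the values
     col_min = torch.min(data, 0).values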
    

    torch.nan_to_num replaces every NaN in a tensor with the value you pass as nan. Pick a value larger than anything that can occur in your data; otherwise the replacement itself could be returned as a column's minimum.
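A variation on the same idea, sketched below, avoids choosing a magic number by filling NaN cells with +inf via `masked_fill`, which can never be smaller than a real value. The helper name `nan_min` is hypothetical, not a PyTorch function. It also restores NaN for any column that is entirely NaN, since that column would otherwise come out as +inf.

```python
import torch

def nan_min(data: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Hypothetical helper: minimum along `dim` that skips NaN entries."""
    nan_mask = torch.isnan(data)
    # NaN cells become +inf, so they can only be the minimum if nothing else is left
    filled = data.masked_fill(nan_mask, float("inf"))
    values = torch.min(filled, dim).values
    # Slices that were all NaN would report +inf; put NaN back for them
    all_nan = nan_mask.all(dim)
    return values.masked_fill(all_nan, float("nan"))

data = torch.tensor([[0., 1., float('nan'), 3.],
                     [4., 5., 6., 7.],
                     [8., 9., 10., 11.]])
print(nan_min(data, 0))  # tensor([0., 1., 6., 3.])
```

Whether an all-NaN column should yield NaN or some other marker is a design choice; drop the last `masked_fill` if +inf suits your use case better.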