
PyTorch operation to detect NaNs


Is there a PyTorch-internal procedure to detect NaNs in tensors? TensorFlow has the tf.is_nan and tf.check_numerics operations. Does PyTorch have something similar? I could not find anything like this in the docs.

I am looking specifically for a PyTorch-internal routine, since I would like this to work on the GPU as well as on the CPU. This rules out NumPy-based solutions (like np.isnan(sometensor.numpy()).any()).


Solution

  • You can always leverage the fact that nan != nan:

    >>> import numpy as np
    >>> import torch
    >>> x = torch.tensor([1, 2, np.nan])
    >>> x
    tensor([  1.,   2., nan.])
    >>> x != x
    tensor([ 0,  0,  1], dtype=torch.uint8)
    

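    As of PyTorch 0.4 there is also torch.isnan (the outputs in this answer are from 0.4; since PyTorch 1.2, both x != x and torch.isnan return torch.bool tensors instead of torch.uint8):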

    >>> torch.isnan(x)
    tensor([ 0,  0,  1], dtype=torch.uint8)
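To get something closer to tf.check_numerics, the elementwise mask from torch.isnan can be reduced with .any() and turned into an error. The following is a rough sketch, assuming PyTorch 1.2 or later; `check_numerics` is a hypothetical helper name, not a PyTorch API. Like tf.check_numerics, it also rejects infinities (via torch.isinf). The mask computation stays on the tensor's own device, so it works for CPU and CUDA tensors alike.

```python
import torch


def check_numerics(tensor: torch.Tensor, message: str) -> torch.Tensor:
    """Hypothetical helper, loosely modeled on tf.check_numerics.

    Raises RuntimeError if `tensor` contains any NaN or infinite value;
    otherwise returns `tensor` unchanged.
    """
    # Elementwise masks reduced to 0-dim bool tensors, computed on the
    # same device as `tensor` (no .numpy() round trip needed).
    has_nan = torch.isnan(tensor).any().item()
    has_inf = torch.isinf(tensor).any().item()
    # .item() above copies a single Python bool back to the host.
    if has_nan or has_inf:
        raise RuntimeError(f"{message}: tensor has NaN={has_nan}, Inf={has_inf}")
    return tensor


if __name__ == "__main__":
    # Use the GPU if one is available; the check itself is device-agnostic.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    good = torch.tensor([1.0, 2.0, 3.0], device=device)
    bad = torch.tensor([1.0, 2.0, float("nan")], device=device)

    check_numerics(good, "good tensor")  # passes silently
    try:
        check_numerics(bad, "bad tensor")
    except RuntimeError as err:
        print(err)  # bad tensor: tensor has NaN=True, Inf=False
```

Note that .item() forces a device-to-host synchronization, so on the GPU you may prefer to run such a check only occasionally rather than after every operation.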