Is there an easy way to filter out all entries containing a NaN value from a tf.data.Dataset instance, like the dropna method in Pandas?
Short example:
import numpy as np
import tensorflow as tf
X = tf.constant([[1,2,3], [0,0,0], [np.nan,np.nan,np.nan], [3,4,5], [np.nan,3,4]])
y = tf.constant([np.nan, 0, 1, 2, 3])
ds = tf.data.Dataset.from_tensor_slices((X,y))
ds = foo(ds) # foo(x) = ?
for x in iter(ds): print(str(x))
What can I use for foo(x)
to get the following output:
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0>)
(<tf.Tensor: shape=(3,), dtype=float32, numpy=array([3., 4., 5.], dtype=float32)>, <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
If you want to try it yourself, here is a Google Colab notebook.
I took a slightly different approach from the existing answer. Rather than using a sum, I'm using tf.reduce_any to check whether any element is NaN:
filter_nan = lambda x, y: not tf.reduce_any(tf.math.is_nan(x)) and not tf.math.is_nan(y)
ds = tf.data.Dataset.from_tensor_slices((X,y)).filter(filter_nan)
[(array([0., 0., 0.], dtype=float32), 0.0),
(array([3., 4., 5.], dtype=float32), 2.0)]
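The lambda above uses Python's `not` and `and` on tensors, so it relies on AutoGraph rewriting them when the predicate is traced. If that conversion is a concern in your setup, a variant using explicit TensorFlow ops might look like the sketch below. The name `has_no_nan` is my own, and the inputs are assumed to be float32 tensors built as in the question.

```python
import numpy as np
import tensorflow as tf

# Same data as in the question; float32 is assumed here to match the expected output.
X = tf.constant(
    [[1, 2, 3], [0, 0, 0], [np.nan, np.nan, np.nan], [3, 4, 5], [np.nan, 3, 4]],
    dtype=tf.float32,
)
y = tf.constant([np.nan, 0, 1, 2, 3], dtype=tf.float32)


def has_no_nan(x, y):
    # Hypothetical helper: True only if neither the features nor the label contain a NaN.
    x_ok = tf.logical_not(tf.reduce_any(tf.math.is_nan(x)))
    y_ok = tf.logical_not(tf.reduce_any(tf.math.is_nan(y)))
    return tf.logical_and(x_ok, y_ok)


ds = tf.data.Dataset.from_tensor_slices((X, y)).filter(has_no_nan)

# Convert to plain Python values for easy inspection.
result = [(x.tolist(), float(label)) for x, label in ds.as_numpy_iterator()]
print(result)
```

Because the predicate only uses TensorFlow ops, it returns a scalar boolean tensor without depending on how the Python operators are converted, which is what Dataset.filter expects.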