I have a training pipeline using tf.data. The dataset contains some bad elements, in my case a height value of 0 or below. How do I drop these bad elements based on their value? Since the dataset is large, I want to remove them inside the pipeline during training rather than cleaning the data beforehand.
Assume the following pseudocode:
def parse_function(element):
    height = element['height']
    if height <= 0: skip()  # How do I skip this element?
    labels = element['label']
    features = {}
    features['height'] = height
    return features, labels
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)
One idea would be to use ds.skip(1) depending on the feature value, or to give these elements some sort of neutral weight/loss. Is there a better way?
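You can use tf.data.Dataset.filter. It takes a predicate that returns a scalar boolean tensor for each element; elements for which it returns False are dropped lazily as the pipeline runs, so nothing has to be cleaned up front. (Dataset.skip(count) won't help here: it drops the first count elements regardless of their value.)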
def filter_func(elem):
    """ return True if the element is to be kept """
    return tf.math.greater(elem['height'], 0)
ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.filter(filter_func)
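A minimal end-to-end sketch, assuming the elements are dicts with `'height'` and `'label'` keys as in the question. The in-memory `data` dict is a hypothetical stand-in for `ds_files`. Filtering before `map` means `parse_function` only ever sees valid elements, so it no longer needs a `skip()` branch:

```python
import tensorflow as tf

# Hypothetical in-memory data standing in for ds_files: each element is a
# dict with a 'height' feature and a 'label'.
data = {
    "height": [1.5, 0.0, 2.0, -1.0, 3.0],
    "label": [0, 1, 0, 1, 1],
}

def filter_func(elem):
    """Return a scalar bool tensor: True keeps the element."""
    return tf.math.greater(elem["height"], 0)

def parse_function(elem):
    """Split a (valid) element into (features, labels)."""
    features = {"height": elem["height"]}
    labels = elem["label"]
    return features, labels

ds = tf.data.Dataset.from_tensor_slices(data)

# Drop bad elements first, then parse, then batch, so every batch except
# possibly the last has the full batch size.
clean_ds = ds.filter(filter_func).map(parse_function).batch(2)

for features, labels in clean_ds:
    print(features["height"].numpy(), labels.numpy())
```

Placing `filter` before `batch` matters: filtering after batching would operate on whole batches rather than individual elements, and the predicate would no longer return a scalar.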