python, tensorflow, tf.data.dataset

Drop bad data from a TensorFlow dataset


I have a training pipeline using tf.data. The dataset contains some bad elements, in my case values of 0. How do I drop these elements based on their value? I want to remove them inside the pipeline during training, since the dataset is large.

Consider the following pseudocode:

def parse_function(element):
    height = element['height']
    if height <= 0:
        skip()  # pseudocode: how do I skip this element?

    labels = element['label']
    features = {'height': height}

    return features, labels

ds = tf.data.Dataset.from_tensor_slices(ds_files)
clean_ds = ds.map(parse_function)

Would something like ds.skip(1) conditioned on the feature value work, or should I give the bad elements some sort of neutral weight/loss instead?


Solution

  • You can use tf.data.Dataset.filter, which keeps only the elements for which the predicate returns True:

    def filter_func(elem):
        """Return True if the element should be kept."""
        return tf.math.greater(elem['height'], 0)

    ds = tf.data.Dataset.from_tensor_slices(ds_files)
    clean_ds = ds.filter(filter_func)
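To see the filter working end to end, here is a minimal sketch that combines it with the question's parse step. It assumes TensorFlow 2.x in eager mode. The in-memory `data` dict, the `keep_valid` name, and the sample values are hypothetical stand-ins for the real dataset, not part of the original question.

```python
import tensorflow as tf

# Hypothetical toy data standing in for the real dataset; the third and
# fourth elements have invalid (non-positive) heights.
data = {
    "height": [170.0, 182.5, 0.0, -1.0, 165.0],
    "label": [0, 1, 1, 0, 1],
}

def keep_valid(element):
    """Return a scalar boolean tensor; True keeps the element."""
    return tf.math.greater(element["height"], 0)

def parse_function(element):
    """Split an already-validated element into (features, label)."""
    features = {"height": element["height"]}
    return features, element["label"]

ds = tf.data.Dataset.from_tensor_slices(data)

# Filter first, so parse_function only ever sees valid elements.
clean_ds = ds.filter(keep_valid).map(parse_function)

kept = [(float(f["height"]), int(l)) for f, l in clean_ds.as_numpy_iterator()]
print(kept)  # [(170.0, 0), (182.5, 1), (165.0, 1)]
```

The filter runs lazily as elements stream through the pipeline, so bad elements are dropped during training without a separate pass over the data. If the value to check is only available after parsing, the same predicate can be applied after `map` instead, adjusted to the `(features, label)` structure that `map` produces.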