
How to skip some records from tf.data.TFRecordDataset based on their labels?


I am currently working on a training application for ResNet. I did not start the project myself; other developers did. I have a huge storage full of source photos: thousands of files, 8 TB in total, divided into hundreds of classes. Sometimes I need to include only a subset of the classes, other times I need to exclude some classes and use the rest, and so on. Each time, I first have to build a dataset, which produces a bunch of TFRecord files, and only then can I start training a model. I want to optimize this process: build one large dataset with all the photos and all the classes, and then filter (include/exclude) classes by their labels at training time.

I found the filter() method, and I understand how to use it with a simple dataset:

import tensorflow as tf

def filter_fn(x):
    if x < -2:
        print('x < -2')
        return True
    elif x > 2:
        print('x > 2')
        return True
    else:
        print('False')
        return False

d = tf.data.Dataset.from_tensor_slices([-4, -3, -2, -1, 0, 1, 2, 3, 4])
d = d.filter(filter_fn)  # named filter_fn so the built-in filter() is not shadowed

The first thing that interests me here is that TF seems to optimize the function somehow. Output:

x < -2
x > 2
False

So the print statements do not run once per element (only three lines are printed for nine elements), yet the filter works as expected and I get [-4, -3, 3, 4].
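The behavior above comes from tracing: Dataset.filter() converts the Python predicate into a TensorFlow graph once, so plain Python code (like print) runs only while tracing, whereas TensorFlow ops run for every element. The following is a minimal sketch to make this visible; the predicate name `keep_outside_band` and the `trace_count` counter are illustrative additions, not part of the original code. It uses tf.print, which is a graph op and therefore prints once per element, next to a Python counter that only increases during tracing.

```python
import tensorflow as tf

trace_count = 0  # incremented by plain Python code, so it counts traces, not elements


def keep_outside_band(x):
    """Keep elements with x < -2 or x > 2, written only with TensorFlow ops."""
    global trace_count
    trace_count += 1          # Python side effect: runs only while the function is traced
    tf.print("element:", x)   # graph op: runs once per element during iteration
    return tf.logical_or(x < -2, x > 2)


d = tf.data.Dataset.from_tensor_slices([-4, -3, -2, -1, 0, 1, 2, 3, 4])
d = d.filter(keep_outside_band)

result = [int(v) for v in d.as_numpy_iterator()]
print(result)       # [-4, -3, 3, 4]
print(trace_count)  # small (typically 1), not 9
```

In the original function, the three printed lines most likely come from AutoGraph turning the `if`/`elif`/`else` on a tensor into graph conditionals and tracing each branch exactly once.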

However, I cannot figure out how to filter a dataset by class label. Each record in the dataset consists of two images and a class label. Here is my test code:

import tensorflow as tf

# Predicate passed to filter(); for now it only prints its inputs and keeps every record.
def test_filter_function(x):
    feature_description = {
        'image_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'image2_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'label': tf.io.FixedLenFeature([], tf.string, default_value='')
    }
    print(x)
    parsed = tf.io.parse_single_example(x, feature_description)
    print(parsed['image_raw'])
    print(parsed['label'])
    return True

# list_of_tfrecord_files: paths to the TFRecord files built earlier
dataset = tf.data.TFRecordDataset(list_of_tfrecord_files)
dataset = dataset.filter(test_filter_function)

And the output is:

Tensor("args_0:0", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:2", shape=(), dtype=string)
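The printed values are symbolic graph tensors, because the predicate is being traced rather than executed eagerly; the label only becomes a concrete value when the dataset is iterated. The following minimal sketch illustrates both sides under stated assumptions: it builds a few toy records in memory with a hypothetical `make_example` helper (placeholder bytes instead of real images) and uses `Dataset.from_tensor_slices` as a stand-in for `TFRecordDataset`, since both yield scalar `tf.string` elements. Inside the mapped function `tf.executing_eagerly()` is False, while in the Python loop `.numpy()` returns bytes that can be decoded to a `str`.

```python
import tensorflow as tf

feature_description = {
    "image_raw": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "image2_raw": tf.io.FixedLenFeature([], tf.string, default_value=""),
    "label": tf.io.FixedLenFeature([], tf.string, default_value=""),
}


def make_example(label):
    """Hypothetical helper: serialize a toy record (image bytes are placeholders)."""
    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    example = tf.train.Example(features=tf.train.Features(feature={
        "image_raw": bytes_feature(b"fake-image-1"),
        "image2_raw": bytes_feature(b"fake-image-2"),
        "label": bytes_feature(label.encode("utf-8")),
    }))
    return example.SerializeToString()


# Stand-in for tf.data.TFRecordDataset: same element type (scalar tf.string).
serialized = [make_example(name) for name in ["cat", "dog", "bird"]]
dataset = tf.data.Dataset.from_tensor_slices(serialized)

seen_eager_inside_fn = []


def parse(record):
    # Runs while tracing: `record` is a symbolic tensor here, so .numpy() is unavailable.
    seen_eager_inside_fn.append(tf.executing_eagerly())
    return tf.io.parse_single_example(record, feature_description)


parsed = dataset.map(parse)

# Iterating outside the pipeline is eager, so .numpy() works and returns bytes.
labels = [features["label"].numpy().decode("utf-8") for features in parsed]
print(labels)  # ['cat', 'dog', 'bird']
```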

What is "args_0:0"? Why do both parsed['image_raw'] and parsed['label'] have the same Tensor type with dtype string? How can I get the class label as a Python string and check whether it is in a set of enabled classes? Is it possible to exclude records whose labels are not in that set, and if so, how?


Solution

  • The first thing that interests me here is that TF seems to optimize the function somehow.

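    The function passed to filter() is not run eagerly for every element. tf.data traces it once into a TensorFlow graph (even when eager execution is enabled), and from then on only the graph's TensorFlow ops run for each element. Plain Python side effects such as print() therefore happen only during tracing; you see three lines because AutoGraph converts the if/elif/else on a tensor into graph conditionals and traces every branch once. See here.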

    What is "args_0:0"? Why do both parsed['image_raw'] and parsed['label'] have the same Tensor type with dtype string? How can I get the class label as a Python string and check whether it is in a set of enabled classes? Is it possible to exclude records whose labels are not in that set, and if so, how?

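    "args_0:0" is just the auto-generated name of the symbolic placeholder tensor that stands for the function's first argument while it is being traced. Both parsed['image_raw'] and parsed['label'] have dtype string because they are declared as tf.string features (stored as BytesList in the records): the image is a byte string, and so is the label. Inside filter() these tensors are symbolic, so .numpy() is not available there; it only works on eager tensors, for example when you iterate over the dataset, where label.numpy() returns bytes that you can .decode() into a Python str. To turn image bytes back into pixels inside the pipeline, use TensorFlow ops such as tf.io.decode_image (for encoded JPEG/PNG data) or tf.io.decode_raw (for raw pixel buffers). See here.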

    So, to filter by label, you either need to wrap the Python check in tf.py_function or, better, write the predicate entirely with TensorFlow ops (for example tf.equal and tf.reduce_any), so that it can be traced into a graph and does not depend on eager execution.
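Both options can be sketched as follows, assuming the label is stored as a UTF-8 byte string as in the question's feature description. `ENABLED_CLASSES`, the class names, and the `toy_record` helper are hypothetical placeholders, and `from_tensor_slices` stands in for `tf.data.TFRecordDataset(list_of_tfrecord_files)`. The graph-only predicate parses just the `label` feature (other features in the record are ignored), so dropped records never have their image bytes decoded; the `tf.py_function` variant runs ordinary Python but is slower and cannot be serialized with the graph.

```python
import tensorflow as tf

# Hypothetical set of classes to keep; replace with your own class names.
ENABLED_CLASSES = tf.constant(["cat", "bird"])

# Parsing only the label keeps the predicate cheap: image bytes are not touched.
LABEL_ONLY = {"label": tf.io.FixedLenFeature([], tf.string, default_value="")}


def record_is_enabled(record):
    """Graph-only predicate on a serialized tf.train.Example."""
    label = tf.io.parse_single_example(record, LABEL_ONLY)["label"]
    # Scalar label vs. vector of classes broadcasts to a bool vector; any() reduces it.
    return tf.reduce_any(tf.equal(label, ENABLED_CLASSES))


def record_is_enabled_py(record):
    """Alternative via tf.py_function: plain Python check, but slower."""
    label = tf.io.parse_single_example(record, LABEL_ONLY)["label"]

    def check(label_tensor):
        # Runs eagerly, so .numpy() works here.
        return label_tensor.numpy().decode("utf-8") in {"cat", "bird"}

    keep = tf.py_function(check, [label], Tout=tf.bool)
    keep.set_shape([])  # filter() expects a scalar boolean
    return keep


def toy_record(label):
    """Hypothetical stand-in for one record of the real TFRecord files."""
    def bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    feature = {
        "image_raw": bytes_feature(b"img1"),
        "image2_raw": bytes_feature(b"img2"),
        "label": bytes_feature(label.encode("utf-8")),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


records = [toy_record(name) for name in ["cat", "dog", "bird", "fish"]]
dataset = tf.data.Dataset.from_tensor_slices(records)  # in practice: TFRecordDataset(files)


def kept_labels(predicate):
    """Apply a predicate and return the surviving labels as Python strings."""
    return [
        tf.io.parse_single_example(r, LABEL_ONLY)["label"].numpy().decode("utf-8")
        for r in dataset.filter(predicate)
    ]


print(kept_labels(record_is_enabled))     # ['cat', 'bird']
print(kept_labels(record_is_enabled_py))  # ['cat', 'bird']
```

To exclude classes instead of including them, wrap the graph predicate's result in tf.logical_not. For hundreds of classes, a tf.lookup.StaticHashTable mapping class names to booleans is another graph-compatible option, though the tf.equal/tf.reduce_any comparison shown here is simpler.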