
Reading a TFRecord file when the features used to encode it are not known


I am very new to TensorFlow, so this might be a very basic question. I have seen examples where custom datasets are converted to TFRecord files using knowledge of the features one wants to use (for example, 'image' and 'label'). When parsing such a TFRecord file back, one has to know the features beforehand (i.e. 'image' and 'label') in order to use the dataset.

My question is: how do we parse TFRecord files when we do not know the features beforehand? Suppose someone gives me a TFRecord file and I want to decode all the features associated with it.

Some examples which I am referring to are: Link 1, Link 2
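For reference, the known-schema approach described above usually looks roughly like this with `tf.io.parse_single_example`: the reader has to spell out every feature name, dtype and shape before it can decode anything. The feature names 'image' and 'label' and the placeholder image bytes are made up for illustration.

```python
import tensorflow as tf

# Build one serialized example whose feature names we chose ourselves
# (the image "bytes" are a placeholder, not a real encoded image).
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'image': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[b'raw-image-bytes'])),
            'label': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[3])),
        }))
serialized = example.SerializeToString()

# The reader must know every name, dtype and shape in advance.
feature_description = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}
parsed = tf.io.parse_single_example(serialized, feature_description)
print(parsed['label'].numpy())  # 3
```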


Solution

  • Here is something that might help: a function that goes through a records file and collects the available information about its features. You could modify it to look only at the first record and return that information, but depending on the case it can be useful to scan all the records, since some features may be optional (present in only some records) or have variable size.

    import tensorflow as tf
    
    def list_record_features(tfrecords_path):
        # Dict of extracted feature information
        features = {}
        # Iterate records
        for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
            # Get record bytes
            example_bytes = rec.numpy()
            # Parse example protobuf message
            example = tf.train.Example()
            example.ParseFromString(example_bytes)
            # Iterate example features
            for key, value in example.features.feature.items():
                # Kind of data in the feature (None if the feature is empty)
                kind = value.WhichOneof('kind')
                # Size of data in the feature (0 for an empty feature)
                size = len(getattr(value, kind).value) if kind else 0
                # Check if feature was seen before
                if key in features:
                    # Check if values match, use None otherwise
                    kind2, size2 = features[key]
                    if kind != kind2:
                        kind = None
                    if size != size2:
                        size = None
                # Save feature data
                features[key] = (kind, size)
        return features
    

    You could use it like this:

    import tensorflow as tf
    
    tfrecords_path = 'data.tfrecord'
    # Make some test records
    with tf.io.TFRecordWriter(tfrecords_path) as writer:
        for i in range(10):
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        # Fixed length
                        'id': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[i])),
                        # Variable length
                        'data': tf.train.Feature(
                            float_list=tf.train.FloatList(value=range(i))),
                    }))
            writer.write(example.SerializeToString())
    # Print extracted feature information
    features = list_record_features(tfrecords_path)
    print(*features.items(), sep='\n')
    # ('id', ('int64_list', 1))
    # ('data', ('float_list', None))
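As a possible follow-up (a sketch, not part of the function above), the summary can be turned into a parsing spec for `tf.io.parse_single_example`: features with the same size in every record become `tf.io.FixedLenFeature`, variable-size ones become `tf.io.VarLenFeature` (which decodes to a sparse tensor). The helper `make_feature_description` and the `KIND_TO_DTYPE` mapping are hypothetical names, and the summary dict is copied by hand from the output above so the block runs on its own. It assumes every feature appears in every record; a feature missing from some records would need a `default_value` or a `VarLenFeature` instead.

```python
import tensorflow as tf

# Map the protobuf oneof names reported by list_record_features to dtypes
# (hypothetical helper table; FloatList values are stored as float32).
KIND_TO_DTYPE = {
    'bytes_list': tf.string,
    'float_list': tf.float32,
    'int64_list': tf.int64,
}

def make_feature_description(features):
    """Hypothetical helper: turn {name: (kind, size)} into a parse spec."""
    description = {}
    for key, (kind, size) in features.items():
        if kind is None:
            # The kind differed between records, so no single dtype fits.
            raise ValueError(f'feature {key!r} has inconsistent types')
        dtype = KIND_TO_DTYPE[kind]
        if size is None:
            # Variable length: parsed as a tf.sparse.SparseTensor.
            description[key] = tf.io.VarLenFeature(dtype)
        else:
            # Same length in every record: parsed as a dense tensor.
            description[key] = tf.io.FixedLenFeature([size], dtype)
    return description

# Summary in the format printed above, copied by hand.
features = {'id': ('int64_list', 1), 'data': ('float_list', None)}
description = make_feature_description(features)

# One record shaped like the test records written above.
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            'id': tf.train.Feature(int64_list=tf.train.Int64List(value=[4])),
            'data': tf.train.Feature(
                float_list=tf.train.FloatList(value=[0.0, 1.0, 2.0, 3.0])),
        }))
parsed = tf.io.parse_single_example(example.SerializeToString(), description)
print(parsed['id'].numpy())                        # [4]
print(tf.sparse.to_dense(parsed['data']).numpy())  # [0. 1. 2. 3.]
```

To decode a whole file, the same spec can be applied to every record, e.g. `tf.data.TFRecordDataset(tfrecords_path).map(lambda r: tf.io.parse_single_example(r, description))`.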