python tensorflow training-data tfrecord

How to inspect the structure of a TFRecord file in TensorFlow 1.13?


I am rather confused by the TFRecord file format, and how to use it. I have a TFRecord, but have little idea as to what it exactly contains and what its structure is. How can I print and inspect a TFRecord and/or its TFExamples? I am essentially asking the same as this question, but the answers to that one are outdated. Printing the output_shapes, output_types or output_classes of my TFRecord tells me nothing (why?). The tf.io.tf_record_iterator() function is deprecated, but TFRecord datasets now appear themselves iterable (but then why would one still need the other iterators?). However, simply printing each iteration returns gibberish, and tf.train.Example.FromString(example) throws a TypeError: a bytes-like object is required, not 'tensorflow.python.framework.ops.EagerTensor'. It's all rather confusing. Simply initializing a tf.data.Dataset using from_tensor_slices() seems so much easier to inspect, and actually gives information on its shape and type.


Solution

  • You can use tf.python_io.tf_record_iterator to inspect a TFRecord file. It creates a generator. To access a single example, iterate over it:

    import tensorflow as tf

    for str_rec in tf.python_io.tf_record_iterator('file.tfrecords'):
        # Each record is a serialized tf.train.Example protobuf
        example = tf.train.Example()
        example.ParseFromString(str_rec)
        print(dict(example.features.feature).keys())
    

    This prints the feature names only:

    dict_keys(['label', 'width', 'image_raw', 'height'])
    

    To also print each feature's type (bytes_list, float_list or int64_list) together with its contents, you'd need

    print(dict(example.features.feature).values())
    

    However, this also prints the raw byte strings, which can flood the terminal for large features such as images.
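One way to see the types without dumping the raw bytes is to report only each feature's kind and its number of values. The following is a minimal sketch, not part of the original answer: `summarize_features` and `summarize_first_record` are hypothetical helper names, and it assumes the records are serialized tf.train.Example protos read with the TF 1.x iterator shown above.

```python
def summarize_features(example):
    """Return {feature_name: (kind, number_of_values)} for a tf.train.Example."""
    summary = {}
    for name, feature in example.features.feature.items():
        # The Feature proto stores its data in a oneof called "kind":
        # 'bytes_list', 'float_list' or 'int64_list' (None if nothing is set).
        kind = feature.WhichOneof("kind")
        n_values = len(getattr(feature, kind).value) if kind else 0
        summary[name] = (kind, n_values)
    return summary


def summarize_first_record(path):
    """Hypothetical helper: summarize the first record of a TFRecord file."""
    # Imported here so that summarize_features itself does not depend on TF.
    import tensorflow as tf

    for str_rec in tf.python_io.tf_record_iterator(path):
        example = tf.train.Example()
        example.ParseFromString(str_rec)
        return summarize_features(example)
    return {}  # the file contained no records
```

Calling `summarize_first_record('file.tfrecords')` would then return a small dictionary mapping each feature name to its kind and value count, which is usually enough to decide how to decode it.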

    If you know how a feature was encoded, you can access and decode its value with

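    import numpy as np
    # dtype is a placeholder: it must match how the bytes were written (e.g. np.uint8)
    string = example.features.feature['image_raw'].bytes_list.value[0]
    output = np.frombuffer(string, dtype=dtype)  # np.fromstring is deprecated for binary data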
    

    You can read more about it here: https://www.tensorflow.org/tutorials/load_data/tf_records

    EDIT: If eager mode is on, you can iterate directly over the dataset object and decode each record, either with NumPy

    for str_rec in tf.data.TFRecordDataset('file.tfrecords'):
        # str_rec is a scalar string tensor; .numpy() returns its raw bytes
        output = np.frombuffer(str_rec.numpy(), dtype=dtype)
    

    or with native TF: tf.io.decode_raw(str_rec, tf.uint8)

    However, this gives you a flattened array, which carries no information about, for example, the image dimensions.
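
If the records are serialized tf.train.Example protos that also store the dimensions, as the 'height' and 'width' keys above suggest, the shape can be recovered after parsing. The sketch below rests on several assumptions that are not confirmed by the original answer: 'height' and 'width' are int64 features, 'image_raw' holds row-major height x width x channels bytes, and the pixel dtype is uint8. `restore_image` and `read_images` are hypothetical names. Passing `str_rec.numpy()` rather than the tensor itself is also what avoids the TypeError from tf.train.Example.FromString mentioned in the question.

```python
import numpy as np


def restore_image(flat, height, width):
    """Reshape a flat pixel array to (height, width, channels)."""
    channels, remainder = divmod(flat.size, height * width)
    if remainder or channels == 0:
        raise ValueError("array size does not match height * width")
    return flat.reshape(height, width, channels)


def read_images(path, dtype=np.uint8):
    """Yield one reshaped image per record (requires eager execution)."""
    # Imported here so that restore_image works without TensorFlow installed.
    import tensorflow as tf

    for str_rec in tf.data.TFRecordDataset(path):
        # FromString needs bytes, so convert the EagerTensor with .numpy()
        example = tf.train.Example.FromString(str_rec.numpy())
        feats = example.features.feature
        flat = np.frombuffer(feats['image_raw'].bytes_list.value[0], dtype=dtype)
        height = feats['height'].int64_list.value[0]  # assumed int64 feature
        width = feats['width'].int64_list.value[0]    # assumed int64 feature
        yield restore_image(flat, height, width)
```

In TF 1.13 eager execution is off by default, so `tf.enable_eager_execution()` would need to be called at program start before iterating over the dataset this way.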