I am very new to TensorFlow, so this may be a very basic question. I have seen examples where custom datasets are converted to TFRecord files using knowledge of the features one wants to store (for example, 'image' and 'label'). When parsing the TFRecord file back, one has to know those features beforehand (i.e. 'image' and 'label') in order to use the dataset.
My question is: how do we parse TFRecord files when we do not know the features beforehand? Suppose someone gives me a TFRecord file and I want to decode all the features stored in it.
Here is something that might help. It is a function that goes through a records file and collects the available information about the features: for each feature name, the kind of list it holds and the number of values, with None standing in for whichever of the two differs between records. You can modify it to look only at the first record and return that information, although depending on the case it may be useful to scan all the records, in case there are optional features present only in some of them or features with variable size.
import tensorflow as tf

def list_record_features(tfrecords_path):
    # Dict of extracted feature information
    features = {}
    # Iterate records
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
        # Get record bytes
        example_bytes = rec.numpy()
        # Parse example protobuf message
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        # Iterate example features
        for key, value in example.features.feature.items():
            # Kind of data in the feature
            kind = value.WhichOneof('kind')
            # Size of data in the feature
            size = len(getattr(value, kind).value)
            # Check if feature was seen before
            if key in features:
                # Check if values match, use None otherwise
                kind2, size2 = features[key]
                if kind != kind2:
                    kind = None
                if size != size2:
                    size = None
            # Save feature data
            features[key] = (kind, size)
    return features
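If a quick look at the first record is enough, the modification mentioned above might look roughly like the sketch below. The name peek_first_record is made up for illustration, and the function assumes the file holds serialized tf.train.Example messages. Because it reads only one record, it cannot tell whether a feature is optional or variable-length elsewhere in the file.

```python
import tensorflow as tf

def peek_first_record(tfrecords_path):
    """Return {feature name: (kind, number of values)} for the first record only.

    Hypothetical helper: assumes each record is a serialized tf.train.Example.
    """
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]).take(1):
        # Parse the raw record bytes as an Example protobuf message
        example = tf.train.Example()
        example.ParseFromString(rec.numpy())
        info = {}
        for key, value in example.features.feature.items():
            # Which of bytes_list / float_list / int64_list is set (None if empty)
            kind = value.WhichOneof('kind')
            # Number of values stored under this feature in this record
            size = len(getattr(value, kind).value) if kind else 0
            info[key] = (kind, size)
        return info
    # The file contained no records
    return {}
```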
You could use list_record_features like this:
import tensorflow as tf

tfrecords_path = 'data.tfrecord'
# Make some test records
with tf.io.TFRecordWriter(tfrecords_path) as writer:
    for i in range(10):
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    # Fixed length
                    'id': tf.train.Feature(
                        int64_list=tf.train.Int64List(value=[i])),
                    # Variable length
                    'data': tf.train.Feature(
                        float_list=tf.train.FloatList(value=range(i))),
                }))
        writer.write(example.SerializeToString())

# Print extracted feature information
features = list_record_features(tfrecords_path)
print(*features.items(), sep='\n')
# ('id', ('int64_list', 1))
# ('data', ('float_list', None))
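Once the feature names, kinds and sizes are known, one possible next step is to turn them into a parsing spec for tf.io.parse_single_example. The sketch below is only an illustration: build_feature_spec and the _DTYPES mapping are names invented here, and it assumes that a feature with a consistent size is present in every record, which the (kind, size) pairs alone cannot guarantee. Features whose kind is None (mixed kinds across records) are skipped.

```python
import tensorflow as tf

# Assumed mapping from the protobuf list kind to a TensorFlow dtype
_DTYPES = {
    'bytes_list': tf.string,
    'float_list': tf.float32,
    'int64_list': tf.int64,
}

def build_feature_spec(features):
    """Build a parsing spec from a {name: (kind, size)} dict (hypothetical helper)."""
    spec = {}
    for key, (kind, size) in features.items():
        dtype = _DTYPES.get(kind)
        if dtype is None:
            # Kind was inconsistent or unknown, so this feature is left out
            continue
        if size is None:
            # Size varies between records: parse as a sparse, variable-length feature
            spec[key] = tf.io.VarLenFeature(dtype)
        else:
            # Same size in every record seen: parse as a dense, fixed-length feature
            spec[key] = tf.io.FixedLenFeature([size], dtype)
    return spec

# Example with the feature information printed above
spec = build_feature_spec({'id': ('int64_list', 1), 'data': ('float_list', None)})
```

The resulting spec could then be passed to tf.io.parse_single_example, for instance inside a map over a tf.data.TFRecordDataset. If a fixed-length feature turns out to be missing from some record, parsing will fail unless a default value is supplied, so treating uncertain features as variable-length may be the safer choice.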