Tags: tensorflow, tensorflow-datasets

When trying to load an external tfrecord with TFDS, given the tf.train.Example, how do I get the tfds.features?


What I need help with / What I was wondering

Hi, I am trying to load external tfrecord files with TFDS. I have read the official doc here, and found that I need to define the feature structure using tfds.features. However, since the tfrecord files are already generated, I do not have control over the generation pipeline. I do, however, know the tf.train.Example structure used by the TFRecordWriter during generation, shown as follows.

from tensorflow.train import BytesList, Example, Feature, Features, Int64List

Example(features=Features(feature={
    'image': Feature(bytes_list=BytesList(value=[img_str])),  # img_str is jpg-encoded raw image bytes
    'caption': Feature(bytes_list=BytesList(value=[caption])),  # caption is a string
    'height': Feature(int64_list=Int64List(value=[height])),
    'width': Feature(int64_list=Int64List(value=[width])),
}))

The doc only describes how to translate tfds.features into the human-readable structure of the tf.train.Example. But it does not mention how to do the reverse, i.e. translate a tf.train.Example into tfds.features, which is needed to automatically add the proper metadata files with tfds.folder_dataset.write_metadata.

I wonder how to translate the above tf.train.Example into tfds.features? Thanks a lot!

BTW, while I understand that it is possible to read the data as-is from the TFRecord with tf.data.TFRecordDataset and then use map(decode_fn) for decoding as suggested here, it seems to me this approach lacks necessary metadata like num_shards or shard_lengths. In this case, I am not sure if it is still OK to use common operations like cache/repeat/shuffle/map/batch on that tf.data.TFRecordDataset. So I think it is better to stick to the tfds approach.
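For reference, the manual decode_fn approach mentioned above would look roughly like this. This is only a sketch: the feature keys mirror the tf.train.Example structure shown earlier, and the file path is a placeholder.

```python
import tensorflow as tf

# Parsing spec matching the tf.train.Example structure above.
feature_spec = {
    'image': tf.io.FixedLenFeature([], tf.string),   # jpg-encoded bytes
    'caption': tf.io.FixedLenFeature([], tf.string),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
}

def decode_fn(serialized):
    """Parse one serialized tf.train.Example and decode the image bytes."""
    example = tf.io.parse_single_example(serialized, feature_spec)
    example['image'] = tf.io.decode_jpeg(example['image'], channels=3)
    return example

# ds = tf.data.TFRecordDataset(['/path/to/data.tfrecord']).map(decode_fn)
```

This works, but as noted it carries no TFDS metadata (num_shards, shard_lengths, split info).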

What I've tried so far

I have searched the official doc for quite some time but cannot find the answer. There is a Scalar class in tfds.features, which I assume could be used to decode an Int64List, but how can I decode a BytesList?

Environment information

  • tensorflow-datasets version: 4.8.2
  • tensorflow version: 2.11.0

Solution

  • After some searching, I found that the simplest solution is:

    features = tfds.features.FeaturesDict({
        'image': tfds.features.Image(),  # ideally add the `shape=(h, w, c)` info if known
        'caption': tfds.features.Text(),
        'height': tf.int32,
        'width': tf.int32,
    })
    

    Then I can either write the metadata with tfds.folder_dataset.write_metadata and load the dataset through a builder, or read the files directly with:

    ds = tf.data.TFRecordDataset(filenames)  # filenames: list of the existing tfrecord paths
    ds = ds.map(features.deserialize_example)
    

    batch, shuffle, etc. will work as expected in both cases.

    The TFDS metadata would be helpful if fine-grained split control is needed.