Tags: python, tensorflow, tensorflow-datasets, tfrecord

How to replace tfds dataset with existing sharded tfrecords


I'm working with cloned code that uses tfds datasets, and I would like to adapt it to a pre-existing set of sharded tfrecords with as few modifications as possible.

Specifically, the cloned code does the following:

builder = tfds.builder(dataset, data_dir=data_dir)
builder.download_and_prepare()
...
estimator.train(
        data_lib.build_input_fn(builder, True), max_steps=train_steps
)

In this code, 'dataset' is the name of a tfds dataset (e.g. cifar10 or others). Instead, I would like to train on an external dataset that is already in sharded tfrecords form, i.e.:

'train_<shard_id>-<no_samples>.tfrecords'

'val_<shard_id>-<no_samples>.tfrecords'

and resides in a bucket (on google cloud if that info helps).

I've been looking into Adding new datasets in TFDS format, but this seems to require a whole pipeline for generating the tfrecords from scratch, which is not possible here and seems redundant given that the tfrecords already exist. I'm sure I'm missing some simple adaptation for existing tfrecords.

Any advice would be much appreciated.


Solution

  • Alona,

    Your expectation is correct: there is a dedicated dataset class, tf.data.TFRecordDataset, for reading data stored in tfrecords. Use it in your input_fn like this:

    def input_fn(file_paths, training=True, batch_size=256):
        # file_paths: list of tfrecord shard paths,
        # e.g. obtained from tf.io.gfile.glob
        dataset = tf.data.TFRecordDataset(file_paths)

        # Note: the records are still serialized tf.train.Example protos;
        # decode them with tf.io.parse_single_example via dataset.map()
        # before feeding them to the model.

        # Shuffle and repeat if you are in training mode.
        if training:
            dataset = dataset.shuffle(1000).repeat()

        return dataset.batch(batch_size)
    
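    To make the records usable by the model, each serialized example has to be decoded. Below is a minimal sketch of that parsing step; the feature keys ("image", "label") and their dtypes are assumptions, since they depend entirely on how the existing tfrecords were written:

    import tensorflow as tf

    # Hypothetical feature spec -- adjust keys and dtypes to match
    # how the existing tfrecords were actually serialized.
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_example(serialized):
        # Decode one serialized tf.train.Example into a dict of tensors.
        return tf.io.parse_single_example(serialized, feature_spec)

    def make_dataset(file_paths, training=True, batch_size=256):
        dataset = tf.data.TFRecordDataset(file_paths)
        dataset = dataset.map(parse_example,
                              num_parallel_calls=tf.data.AUTOTUNE)
        if training:
            dataset = dataset.shuffle(1000).repeat()
        return dataset.batch(batch_size)
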

    read more in the TensorFlow documentation on tf.data and the TFRecord format.
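    Since the shards follow a fixed naming pattern and live in a Cloud Storage bucket, the file list can be built with tf.io.gfile.glob, which understands gs:// paths as well as local ones. The bucket path below is hypothetical:

    import tensorflow as tf

    # Hypothetical bucket location -- replace with the actual gs:// path.
    TRAIN_PATTERN = "gs://my-bucket/data/train_*.tfrecords"

    def list_shards(pattern):
        # Expand a glob pattern into a sorted list of shard paths.
        # tf.io.gfile.glob works for both local files and GCS buckets.
        return sorted(tf.io.gfile.glob(pattern))

    The resulting list can be passed directly to tf.data.TFRecordDataset in the input_fn above.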