Tags: tensorflow, machine-learning, dataset, conv-neural-network, tfrecord

How to load in a downloaded tfrecord dataset into TensorFlow?


I am quite new to TensorFlow, and have never worked with TFRecords before.

I have downloaded a dataset of images from online and the download format was TFRecord.

This is the file structure in the downloaded dataset:

[Screenshots of the dataset's folder structure, including the contents of the "test" folder.]

I want to load the training, validation and testing data into TensorFlow much as you would load a built-in dataset. For example, you might load MNIST like this and get arrays of pixel data plus arrays of the corresponding image labels:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

However, I have no idea how to do so.

I know that I can use dataset = tf.data.TFRecordDataset(filename) somehow to open the dataset, but would this act on the entire dataset folder, one of the subfolders, or the actual files? If it is the actual files, would it be on the .TFRecord file? How do I use/what do I do with the .PBTXT file which contains a label map?

And even after opening the dataset, how can I extract the data and create the necessary arrays which I can then feed into a TensorFlow model?


Solution

  • It's mostly archaeology, plus a few tricks.

    1. First, I'd read the README.dataset and README.roboflow files. Can you show us what's in them?

    2. Second, .pbtxt files are text-formatted protobufs, so you may be able to understand what that file is just by opening it in a text editor. Can you show us what's in it?

    3. The thing to remember about a TFRecord file is that it's nothing but a sequence of binary records. tf.data.TFRecordDataset('balls.tfrecord') will give you a dataset that yields those records in order.

    Step 3 is the hard part, because here you'll have binary blobs of data, but we don't have any clues yet about how they're encoded.

    It's common for TFRecord files to contain serialized tf.train.Example protos.
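    On step 2: if the .pbtxt turns out to be a label map laid out as a series of `item { name: ..., id: ... }` blocks (an assumption until you've actually opened it), a small parser along these lines might be enough to get an id-to-name dict. The sample text and class names below are made up for illustration, and `parse_label_map` is a hypothetical helper, not part of TensorFlow.

```python
import re

# Hypothetical label-map text in an `item { ... }` layout; the real file's
# contents may differ, so check it in a text editor first.
SAMPLE_PBTXT = '''
item {
    name: "football",
    id: 1,
    display_name: "football"
}
item {
    name: "tennis ball",
    id: 2,
    display_name: "tennis ball"
}
'''


def parse_label_map(text):
    """Return {id: name} for every `item { ... }` block in a label-map string."""
    labels = {}
    for block in re.findall(r'item\s*\{(.*?)\}', text, flags=re.DOTALL):
        # \b keeps `name:` from matching inside `display_name:`.
        name = re.search(r'\bname:\s*["\'](.*?)["\']', block)
        item_id = re.search(r'\bid:\s*(\d+)', block)
        if name and item_id:
            labels[int(item_id.group(1))] = name.group(1)
    return labels


label_map = parse_label_map(SAMPLE_PBTXT)
print(label_map)
```

    For the real file you would read its text with `open(path).read()` and pass that in instead of the sample string.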

    So it would be worth a shot to try and decode it as a tf.train.Example to see if that tells us what's inside.


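    import tensorflow as tf

    # Each element is one raw record: a scalar string tensor of serialized bytes.
    # Grab only the first one.
    for record in tf.data.TFRecordDataset('balls.tfrecord'):
      break

    # Try to interpret those bytes as a tf.train.Example protobuf.
    example = tf.train.Example()
    example.ParseFromString(record.numpy())
    print(example)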
    

    The Example object is just a representation of a dict. If you get something other than an error there, look for the dict keys and see if you can make sense of them.
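    If the record holds an encoded image, printing the whole Example dumps a wall of bytes. A summary of the keys and value types is usually easier to read. The sketch below builds a stand-in Example with made-up keys (`image/encoded` and `image/object/class/label` are assumptions, not something known about this dataset) so that it runs without the real file. `summarize` is a hypothetical helper you could call on the `example` parsed from 'balls.tfrecord' instead.

```python
import tensorflow as tf

# Stand-in record with made-up keys so the snippet runs on its own;
# in practice `example` would be the one parsed from the real file.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'fake-bytes'])),
    'image/object/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1, 2])),
}))


def summarize(example):
    """Map each feature key to (kind, number of values)."""
    summary = {}
    for key, feature in example.features.feature.items():
        # A Feature holds exactly one of bytes_list, float_list or int64_list.
        kind = feature.WhichOneof('kind')
        count = len(getattr(feature, kind).value) if kind else 0
        summary[key] = (kind, count)
    return summary


feature_summary = summarize(example)
for key in sorted(feature_summary):
    print(key, feature_summary[key])
```

    The kinds map onto dtypes for the decoding step: `bytes_list` to `tf.string`, `int64_list` to `tf.int64`, and `float_list` to `tf.float32`.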

    Then to make a dataset that decodes them you'll want something like:

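    # key_dtypes maps each feature key you found above to its dtype
    # (tf.string, tf.int64 or tf.float32); fill it in from the printed Example.
    def decode(record):
      features = {key: tf.io.RaggedFeature(dtype) for key, dtype in key_dtypes.items()}
      # The dataset yields one serialized record at a time, so parse singly.
      return tf.io.parse_single_example(record, features)


    ds = tf.data.TFRecordDataset('balls.tfrecord')
    ds = ds.map(decode)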
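
    To get from there to MNIST-style arrays, here is a rough end-to-end sketch. It writes a tiny stand-in TFRecord file so that it runs on its own, then reads it back. Everything about the record layout is an assumption: the keys `image/encoded` and `image/object/class/label`, JPEG-encoded images, and at least one label per image. Swap in whatever the printed Example actually shows. Note that TFRecordDataset takes paths to the .tfrecord files themselves (one path or a list of them), not a folder.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# --- Build a tiny stand-in file so the sketch is self-contained. ---
# The feature keys here are assumptions, not known facts about the dataset.
def make_example(label):
    pixels = tf.zeros([8, 8, 3], dtype=tf.uint8)
    jpeg = tf.io.encode_jpeg(pixels).numpy()
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[jpeg])),
        'image/object/class/label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }))


path = os.path.join(tempfile.mkdtemp(), 'sample.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    for label in [1, 2, 1]:
        writer.write(make_example(label).SerializeToString())

# --- Read it back and decode each record. ---
feature_spec = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
}


def decode(record):
    parsed = tf.io.parse_single_example(record, feature_spec)
    image = tf.io.decode_jpeg(parsed['image/encoded'], channels=3)
    image = tf.image.resize(image, [32, 32])  # fixed size so images can be stacked
    labels = tf.sparse.to_dense(parsed['image/object/class/label'])
    return image, labels[0]  # keeps only the first label; fails if there are none


ds = tf.data.TFRecordDataset([path]).map(decode)

# Pull everything into memory as NumPy arrays.
pairs = list(ds.as_numpy_iterator())
images = np.stack([image for image, _ in pairs])
labels = np.array([label for _, label in pairs])
print(images.shape, labels)
```

    For anything larger than a toy dataset you probably don't want the arrays at all: `model.fit` accepts a batched `tf.data.Dataset` directly, so `ds.batch(32)` can be passed in place of `(images, labels)`.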