
Tensorflow 2.0: how to transform from MapDataset (after reading from TFRecord) to some structure that can be input to model.fit


I've stored my training and validation data on two separate TFRecord files, in which I store 4 values: signal A (float32 shape (150,)), signal B (float32 shape (150,)), label (scalar int64), id (string). My parsing function for reading is:

def _parse_data_function(sample_proto):

    raw_signal_description = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'id': tf.io.FixedLenFeature([], tf.string),
    }

    for key, item in SIGNALS.items():
        raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)

    # Parse the input tf.Example proto using the dictionary above.
    return tf.io.parse_single_example(sample_proto, raw_signal_description)

where SIGNALS is a dictionary mapping signal name->signal shape. Then, I read the raw datasets:

training_raw = tf.data.TFRecordDataset(<path to training>, compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>, compression_type='GZIP')

and use map to parse the values:

training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)

Displaying the header of training_data or val_data, I get:

<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>

which is pretty much as expected. I also checked some of the values for consistency and they seemed to be correct.

Now, to my issue: how do I get from the MapDataset, with the dictionary like structure, to something that can be given as input to the model?

The input to my model is the pair (Signal A, label), though in the future I will use Signal B as well.

The simplest approach seemed to be to create a generator over the elements that I want. Something like:

def data_generator(mapdataset):
    for sample in mapdataset:
        yield (sample['Signal A'], sample['label'])

However, with this approach I lose some of the conveniences of Datasets, such as batching, and it is also not clear how to use the same approach for the validation_data parameter of model.fit. Ideally, I would only convert between the dictionary representation and a Dataset representation that iterates over pairs of Signal A tensors and labels.
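(For reference, one way to do this conversion without a Python generator, so batching and the other Dataset conveniences are preserved, is to map a selection function over the parsed dataset. A minimal sketch, using a small `from_tensor_slices` dataset of dicts as a stand-in for the parsed MapDataset:)

```python
import tensorflow as tf

# Stand-in for the parsed dataset: 4 samples with the same dict structure
parsed = tf.data.Dataset.from_tensor_slices({
    'Signal A': tf.zeros((4, 150), tf.float32),
    'Signal B': tf.zeros((4, 150), tf.float32),
    'label': tf.zeros((4,), tf.int64),
    'id': tf.fill((4,), 'sample'),
})

# Select only the (Signal A, label) pair from each dict element
pairs = parsed.map(lambda d: (d['Signal A'], d['label']))
print(pairs.element_spec)
```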

EDIT: My end product should be something with a header akin to <TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>, but not necessarily a TensorSliceDataset.


Solution

  • You can simply do this in the parse function. For example:

    def _parse_data_function(sample_proto):
    
        raw_signal_description = {
            'label': tf.io.FixedLenFeature([], tf.int64),
            'id': tf.io.FixedLenFeature([], tf.string),
        }
    
        for key, item in SIGNALS.items():
            raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
    
        # Parse the input tf.Example proto using the dictionary above.
        parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)
    
        return parsed['Signal A'], parsed['label']
    

    If you map this function over the TFRecordDataset, you will have a dataset of tuples (signal_a, label) instead of a dataset of dictionaries. You should be able to put this into model.fit directly.
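(To illustrate that last point, here is a hedged sketch of the final pipeline: a tuple dataset batched and passed straight to model.fit, with validation_data handled the same way. The tiny two-class model and batch size are placeholders, not part of the original question:)

```python
import tensorflow as tf

# Stand-in for the mapped (signal_a, label) tuple datasets
pairs = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((8, 150)), tf.zeros((8,), tf.int64))
)
train_ds = pairs.batch(4)
val_ds = pairs.batch(4)  # in practice, built from the validation TFRecord

# Placeholder model consuming (150,) signals, scalar integer labels
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(150,)),
    tf.keras.layers.Dense(2, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Both datasets plug into fit directly; no generator needed
history = model.fit(train_ds, validation_data=val_ds, epochs=1, verbose=0)
```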