
How to feed a TensorFlow dataset pipeline with examples that are triples of tensors


As described on the official TensorFlow website, we can feed a dataset pipeline with examples that are pairs of tensors (input, label). How can I add one more item, like (input, label1, label2)?


Solution

  • Simple!

    Just have your dataset function return a dictionary instead.

    The code below comes from the page you linked, at the very bottom.

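    import tensorflow as tf  # TensorFlow 1.x API
    
    # num_epochs=None makes `Dataset.repeat()` repeat indefinitely.
    def dataset_input_fn(num_epochs=None):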
      filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
      dataset = tf.data.TFRecordDataset(filenames)
    
      # Use `tf.parse_single_example()` to extract data from a `tf.Example`
      # protocol buffer, and perform any additional per-record preprocessing.
      def parser(record):
        keys_to_features = {
            "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
            # The default for an int64 feature must be an integer, not a string.
            "date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)
    
        # Perform additional preprocessing on the parsed data.
        image = tf.image.decode_jpeg(parsed["image_data"])
        image = tf.reshape(image, [299, 299, 1])
        label = tf.cast(parsed["label"], tf.int32)
    
        return {"image_data": image, "date_time": parsed["date_time"]}, label
    
      # Use `Dataset.map()` to build a pair of a feature dictionary and a label
      # tensor for each example.
      dataset = dataset.map(parser)
      dataset = dataset.shuffle(buffer_size=10000)
      dataset = dataset.batch(32)
      dataset = dataset.repeat(num_epochs)
      iterator = dataset.make_one_shot_iterator()
    
      # `features` is a dictionary in which each value is a batch of values for
      # that feature; `labels` is a batch of labels.
      features, labels = iterator.get_next()
      return features, labels
    

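    Here, features is a dictionary with the fields image_data and date_time. You can pack as many tensors as you want into either features or labels this way, while still returning just two outputs. For your case, make the labels a dictionary as well, for example {"label1": label1, "label2": label2}.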
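
For the (input, label1, label2) case specifically, here is a minimal sketch of the same idea. It assumes TensorFlow 2.x with eager execution (on 1.x you would pull elements through `make_one_shot_iterator()` as in the code above), and the toy arrays and the names `label1` and `label2` are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 6 examples with 4 features each, plus two label sets.
inputs = np.arange(24, dtype=np.float32).reshape(6, 4)
label1 = np.array([0, 1, 0, 1, 0, 1], dtype=np.int32)
label2 = np.array([2, 2, 3, 3, 4, 4], dtype=np.int32)

# Option A: keep the (features, labels) pair and make labels a dictionary.
pair_ds = tf.data.Dataset.from_tensor_slices(
    (inputs, {"label1": label1, "label2": label2})
).batch(3)

# Option B: a flat triple, usable when you iterate over the dataset yourself.
triple_ds = tf.data.Dataset.from_tensor_slices(
    (inputs, label1, label2)
).batch(3)

# Take the first batch from each dataset.
features, labels = next(iter(pair_ds))
x, y1, y2 = next(iter(triple_ds))

print(features.shape, sorted(labels.keys()))
print(x.shape, y1.numpy(), y2.numpy())
```

Option A is probably the safer choice when the dataset feeds an Estimator, which expects a (features, labels) pair. Note that Keras `Model.fit` interprets a 3-tuple as (inputs, targets, sample_weights), so Option B is best kept for custom training loops.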