Tags: python, c++, tensorflow, protocol-buffers, tfrecord

Generating TFRecord format data from C++


I'm trying to use the TFRecord format to record data from C++ and then use it in Python to feed a TensorFlow model.

TL;DR: Simply serializing proto messages into a stream doesn't satisfy the .tfrecord format requirements of the Python TFRecordDataset class. Is there an equivalent of the Python TFRecordWriter in C++ (either in the TensorFlow or in the Google Protobuf libraries) to generate proper .tfrecord data?

Details:

The simplified C++ code looks like this:

tensorflow::Example sample;
sample.mutable_features()->mutable_feature()->operator[]("a").mutable_float_list()->add_value(1.0);

std::ofstream out;
out.open("cpp_example.tfrecord", std::ios::out | std::ios::binary);
sample.SerializeToOstream(&out);

In Python, I'm trying to create a TensorFlow dataset with TFRecordDataset, but apparently it expects extra header/footer information in the .tfrecord file (rather than a plain list of serialized proto messages):

import tensorflow as tf
tfrecord_dataset = tf.data.TFRecordDataset(filenames="cpp_example.tfrecord")
next(tfrecord_dataset.as_numpy_iterator())

output:

tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0 [Op:IteratorGetNext]

Note that there is nothing wrong with the recorded binary file itself, as the following code prints valid output:

import tensorflow as tf
with open("cpp_example.tfrecord", "rb") as p:
    example = tf.train.Example.FromString(p.read())
print(example)

output:

features {
  feature {
    key: "a"
    value {
      float_list {
        value: 1.0
      }
    }
  }
}

By comparing the binary output generated by my C++ example with the output generated by the Python TFRecordWriter, I observed additional header and footer bytes in the latter. Unfortunately, what these extra bytes represent appeared to be an implementation detail (probably compression type and some extra info), and I couldn't trace it any deeper than a class in the Python libraries that just exposes an interface from _pywrap_tfe.so.

I have seen advice saying that .tfrecord is just normal Google protobuf data. Perhaps I'm missing where to find a protobuf data writer (other than serializing proto messages into the output stream)?


Solution

  • It turns out that the tensorflow::io::RecordWriter class of the TensorFlow C++ library does the job.

    #include <tensorflow/core/lib/io/record_writer.h>
    
    #include <tensorflow/core/platform/default/posix_file_system.h>
    #include <tensorflow/core/example/example.pb.h>
    
    // ...
    
    // Create WritableFile and instantiate RecordWriter.
    tensorflow::PosixFileSystem posixFileSystem;
    std::unique_ptr<tensorflow::WritableFile> writableFile;
    
    posixFileSystem.NewWritableFile("cpp_example.tfrecord", &writableFile);
    
    tensorflow::io::RecordWriter recordWriter(writableFile.get(), tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions(""));
    
    // ...
    tensorflow::Example sample;
    
    // ...
    
    // Serialize proto message into a buffer and record in tfrecord format.
    std::string buffer;
    sample.SerializeToString(&buffer);
    recordWriter.WriteRecord(buffer);
    
    

    It would be helpful if this class were referenced somewhere in the TFRecord documentation.
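
For context on the extra header and footer bytes: the TFRecord guide in the TensorFlow documentation describes each record as an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the payload, and a 4-byte masked CRC-32C of the payload. The sketch below reproduces that framing for uncompressed files without linking against TensorFlow. The helper names (Crc32c, MaskedCrc32c, AppendLittleEndian, FrameRecord, WriteRecord) are hypothetical and not part of any library; treat it as an illustration of the layout, not a replacement for RecordWriter.

```cpp
#include <cstdint>
#include <ostream>
#include <string>

// CRC-32C (Castagnoli) with reflected polynomial 0x82F63B78, computed bit by bit.
// Slow but dependency-free; a table-driven or hardware version would be faster.
inline std::uint32_t Crc32c(const std::string& data) {
  std::uint32_t crc = 0xFFFFFFFFu;
  for (unsigned char byte : data) {
    crc ^= byte;
    for (int bit = 0; bit < 8; ++bit) {
      crc = (crc & 1u) ? ((crc >> 1) ^ 0x82F63B78u) : (crc >> 1);
    }
  }
  return crc ^ 0xFFFFFFFFu;
}

// TFRecord stores a "masked" CRC: rotate right by 15 bits, then add a constant.
inline std::uint32_t MaskedCrc32c(const std::string& data) {
  const std::uint32_t crc = Crc32c(data);
  return ((crc >> 15) | (crc << 17)) + 0xA282EAD8u;
}

// Appends the lowest `num_bytes` bytes of `value` to `out` in little-endian order.
inline void AppendLittleEndian(std::string* out, std::uint64_t value, int num_bytes) {
  for (int i = 0; i < num_bytes; ++i) {
    out->push_back(static_cast<char>((value >> (8 * i)) & 0xFF));
  }
}

// Hypothetical helper: wraps one serialized message in TFRecord framing.
inline std::string FrameRecord(const std::string& payload) {
  std::string length_bytes;
  AppendLittleEndian(&length_bytes, payload.size(), 8);

  std::string frame = length_bytes;                           // uint64 length
  AppendLittleEndian(&frame, MaskedCrc32c(length_bytes), 4);  // CRC of length
  frame += payload;                                           // data
  AppendLittleEndian(&frame, MaskedCrc32c(payload), 4);       // CRC of data
  return frame;
}

// Hypothetical helper: writes one framed record to a stream opened in binary mode.
inline void WriteRecord(std::ostream& out, const std::string& payload) {
  const std::string frame = FrameRecord(payload);
  out.write(frame.data(), static_cast<std::streamsize>(frame.size()));
}
```

Under these assumptions, the original C++ snippet would call sample.SerializeToString(&buffer) and pass buffer to WriteRecord(out, buffer) instead of calling SerializeToOstream directly. This also suggests why the reader reported a corrupted record at offset 0: it would have interpreted the first bytes of the raw serialized Example as a length and checksum.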