Tags: python, tensorflow, keras, generator, tf.keras

Most scalable way to use generators with tf.data? The tf.data guide says `from_generator` has limited scalability


tf.data has a `from_generator` constructor, but it doesn't seem to be scalable. From the official guide:

Caution: While this is a convenient approach it has limited portability and scalability. It must run in the same python process that created the generator, and is still subject to the Python GIL.

https://www.tensorflow.org/guide/data#consuming_python_generators

And in the official API documentation:

NOTE: The current implementation of Dataset.from_generator() uses tf.numpy_function and inherits the same constraints. In particular, it requires the Dataset- and Iterator-related operations to be placed on a device in the same process as the Python program that called Dataset.from_generator(). The body of generator will not be serialized in a GraphDef, and you should not use this method if you need to serialize your model and restore it in a different environment.

NOTE: If generator depends on mutable global variables or other external state, be aware that the runtime may invoke generator multiple times (in order to support repeating the Dataset) and at any time between the call to Dataset.from_generator() and the production of the first element from the generator. Mutating global variables or external state can cause undefined behavior, and we recommend that you explicitly cache any external state in generator before calling Dataset.from_generator().

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator

However, generators are a fairly common approach for training over very large amounts of data, so there must be some alternative best practice, yet the official TensorFlow data guide does not give any information on this.
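For reference, this is the pattern in question, as a minimal sketch (the generator, shapes, and dtypes are made up for illustration, and `output_signature` assumes TF 2.4+):

```python
import numpy as np
import tensorflow as tf

def sample_generator():
    # Hypothetical generator standing in for a real data source.
    for i in range(1000):
        yield np.random.rand(28, 28).astype(np.float32), i % 10

# The generator runs in the same Python process, under the GIL,
# which is the scalability limitation the guide warns about.
dataset = tf.data.Dataset.from_generator(
    sample_generator,
    output_signature=(
        tf.TensorSpec(shape=(28, 28), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)
```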


Solution

  • Iterate through your generator once and write the data to TFRecord files. Then train from a `TFRecordDataset` instead; see the sketch after this list. This is the relevant guide:

    https://www.tensorflow.org/tutorials/load_data/tfrecord

    TensorFlow is built to consume this kind of dataset efficiently, including in multi-GPU setups.

    Sharding the data to disk also improves shuffling.
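A hedged sketch of that workflow, assuming the toy `sample_generator` above (the shard count, file names, and feature layout are illustrative, not prescribed by the guide):

```python
import numpy as np
import tensorflow as tf

def sample_generator():
    # Hypothetical generator; replace with the real data source.
    for i in range(1000):
        yield np.random.rand(28, 28).astype(np.float32), i % 10

def serialize_example(features, label):
    # Pack one (features, label) pair into a tf.train.Example.
    feature = {
        "features": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[features.tobytes()])),
        "label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()

# 1. Drain the generator once, writing sharded TFRecord files to disk.
num_shards = 4
writers = [tf.io.TFRecordWriter(f"data-{i:03d}.tfrecord")
           for i in range(num_shards)]
for i, (features, label) in enumerate(sample_generator()):
    writers[i % num_shards].write(serialize_example(features, label))
for w in writers:
    w.close()

# 2. Read the shards back with TFRecordDataset, in parallel.
def parse_example(record):
    parsed = tf.io.parse_single_example(record, {
        "features": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    })
    features = tf.reshape(
        tf.io.decode_raw(parsed["features"], tf.float32), (28, 28))
    return features, parsed["label"]

files = tf.data.Dataset.list_files("data-*.tfrecord")
dataset = (tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE)
           .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
           .shuffle(1024)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))
```

The one-time write cost buys you a pipeline that no longer depends on the Python generator at training time: the shards can be read in parallel, shuffled at the file and record level, and consumed by multiple processes or GPUs.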