keras, tensorflow2.0, multi-gpu

How to use a keras.utils.Sequence data generator with tf.distribute.MirroredStrategy for multi-GPU model training in TensorFlow?


I want to train a model on several GPUs using TensorFlow 2.0. In the TensorFlow guide for distributed training (https://www.tensorflow.org/guide/distributed_training), a tf.data dataset is converted into a distributed dataset as follows:

mirrored_strategy = tf.distribute.MirroredStrategy()
dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)

However, I want to use my own custom data generator instead (for example, a keras.utils.Sequence data generator, together with keras.utils.data_utils.OrderedEnqueuer for asynchronous batch generation). But the mirrored_strategy.experimental_distribute_dataset method only supports tf.data datasets. How do I use a Keras data generator instead?
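
For reference, this is the general shape of the Sequence generator I mean (MySequence, the in-memory arrays, and the batch size here are only an illustration, not my real generator):

import math
import tensorflow as tf

class MySequence(tf.keras.utils.Sequence):
    """Illustrative Sequence that serves (x, y) batches from arrays."""
    def __init__(self, x, y, batch_size):
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]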

Thank you!


Solution

  • I used tf.data.Dataset.from_generator with my keras.utils.Sequence in the same situation, and it solved my issues!

    import tensorflow as tf
    from tensorflow.keras.utils import OrderedEnqueuer

    train_generator = SegmentationMultiGenerator(datasets, folder)  # my keras.utils.Sequence object

    def generator():
        multi_enqueuer = OrderedEnqueuer(train_generator, use_multiprocessing=True)
        multi_enqueuer.start(workers=10, max_queue_size=10)
        batches = multi_enqueuer.get()  # create the output iterator once, outside the loop
        while True:
            batch_xs, batch_ys, dset_index = next(batches)  # I have three outputs
            yield batch_xs, batch_ys, dset_index

    dataset = tf.data.Dataset.from_generator(generator,
                                             output_types=(tf.float64, tf.float64, tf.int64),
                                             output_shapes=(tf.TensorShape([None, None, None, None]),
                                                            tf.TensorShape([None, None, None, None]),
                                                            tf.TensorShape([None, None])))

    strategy = tf.distribute.MirroredStrategy()

    train_dist_dataset = strategy.experimental_distribute_dataset(dataset)

    

    Note that this is my first working solution. For now I have found it most convenient to simply put None in place of the real output shapes, which works.
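
    To actually train on train_dist_dataset you can write a custom training loop. Here is a minimal sketch under stated assumptions: make_model, GLOBAL_BATCH_SIZE, and steps_per_epoch are placeholders (not part of the solution above), and strategy.run was named experimental_run_v2 in early TF 2.x releases.

    with strategy.scope():
        model = make_model()  # placeholder: build your model under the strategy scope
        loss_object = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE)
        optimizer = tf.keras.optimizers.Adam()

    def train_step(inputs):
        batch_xs, batch_ys, dset_index = inputs  # the generator's three outputs
        with tf.GradientTape() as tape:
            preds = model(batch_xs, training=True)
            # Average per-example losses over the global batch, as the
            # distributed-training guide recommends.
            loss = tf.nn.compute_average_loss(loss_object(batch_ys, preds),
                                              global_batch_size=GLOBAL_BATCH_SIZE)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    @tf.function
    def distributed_train_step(inputs):
        per_replica_losses = strategy.run(train_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    for step, inputs in enumerate(train_dist_dataset):
        loss = distributed_train_step(inputs)
        if step + 1 >= steps_per_epoch:
            break  # the generator loops forever, so each epoch must be bounded manually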