
How are training data 'batches' distributed to workers in Tensorflow?


I am running distributed TensorFlow with the CIFAR10 example, using up to 128 workers and 1 parameter server.

I was wondering whether FLAGS.batch_size determines the size of the batch processed by EACH worker, or whether it is the total batch size that gets split across ALL workers.

This difference has performance implications: splitting a single batch across too many workers leads to too much communication relative to computation.


Solution

  • The batch size in the distributed CIFAR10 example refers to the per-GPU batch size.

    (But it's a good question to ask - some of the synchronous models refer to it as the aggregate batch size!)
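A minimal sketch, assuming TF 2.x with tf.distribute (not the legacy CIFAR10 example from the question), illustrating the two conventions the answer contrasts: the question's FLAGS.batch_size is a per-worker/per-GPU batch, whereas strategies such as MirroredStrategy treat the batch passed to dataset.batch() as the aggregate (global) batch and split it across replicas. The constant names and dataset below are illustrative.

```python
import tensorflow as tf

PER_REPLICA_BATCH = 128                      # analogue of FLAGS.batch_size (per-GPU)
strategy = tf.distribute.MirroredStrategy()  # one replica per local GPU

# Aggregate number of examples consumed per synchronous step across all replicas.
global_batch = PER_REPLICA_BATCH * strategy.num_replicas_in_sync

# Toy dataset; with tf.distribute, batch() takes the GLOBAL batch size,
# and the strategy hands each replica PER_REPLICA_BATCH examples per step.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([1024, 32, 32, 3]),
     tf.random.uniform([1024], maxval=10, dtype=tf.int32))
).batch(global_batch)

dist_dataset = strategy.experimental_distribute_dataset(dataset)
```

In the older per-GPU convention (as in the distributed CIFAR10 example), the effective global batch is simply batch_size × number_of_workers, so scaling to 128 workers multiplies the examples consumed per step by 128 unless the flag is reduced accordingly.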