Search code examples
tensorflowdistributed

Distributed Tensorflow: Synchronous training stalls indefinitely


I have a distributed setup of one ps task server, and two worker task servers. Each running on CPU. I've run the following example asynchronously, but it doesn't work synchronously. I'm not sure if I'm doing anything wrong with the code:

import math
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")

# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                           "Directory for storing mnist data")
tf.app.flags.DEFINE_integer("batch_size", 3, "Training batch size")

FLAGS = tf.app.flags.FLAGS

IMAGE_PIXELS = 28

steps = 1000

def main(_):
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

  # Create and start a server for the local task.
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)

  tf.logging.set_verbosity(tf.logging.DEBUG)
  if FLAGS.job_name == "ps":
    server.join()
  elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

      with tf.name_scope('Input'):
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="X")
        y_ = tf.placeholder(tf.float32, [None, 10], name="LABELS")

      W = tf.Variable(tf.zeros([IMAGE_PIXELS * IMAGE_PIXELS, 10]), name="W")
      b = tf.Variable(tf.zeros([10]), name="B")
      y = tf.matmul(x, W) + b
      y = tf.identity(y, name="Y")

      with tf.name_scope('CrossEntropy'):
        cross_entropy = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

      global_step = tf.Variable(0, name="STEP")

      with tf.name_scope('Train'):
        opt = tf.train.GradientDescentOptimizer(0.5)
        opt = tf.train.SyncReplicasOptimizer(opt, 
                                replicas_to_aggregate=2,
                                total_num_replicas=2, 
                                name="SyncReplicasOptimizer")
        train_step = opt.minimize(cross_entropy, global_step=global_step)

      with tf.name_scope('Accuracy'):
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

      saver = tf.train.Saver()
      summary_op = tf.summary.merge_all()

#      init_op = tf.initialize_all_variables()
      init_op = tf.global_variables_initializer()

    # Create a "supervisor", which oversees the training process.
    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                             logdir="/tmp/train_logs",
                             init_op=init_op,
                             summary_op=summary_op,
                             saver=saver,
                             global_step=global_step,
                             save_model_secs=600)

    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=True,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

    # The supervisor takes care of session initialization, restoring from
    # a checkpoint, and closing when done or an error occurs.
    with sv.managed_session(server.target, config=config) as sess:
      # Loop until the supervisor shuts down or 1000000 steps have completed.
      writer = tf.summary.FileWriter("~/tensorboard_data", sess.graph)
      step = 0
      while not sv.should_stop() and step < steps:
        print("Starting step %d" % step)
        # Run a training step asynchronously.
        # See `tf.train.SyncReplicasOptimizer` for additional details on how to
        # perform *synchronous* training.

        old_step = step

        batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
        train_feed = {x: batch_xs, y_: batch_ys}

        _, step = sess.run([train_step, global_step], feed_dict=train_feed)

#        if step % 2 == 0: 
        print ("Done step %d, next step %d\n" % (old_step, step))

      # Test trained model
      print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

    # Ask for all the services to stop.
    sv.stop()

if __name__ == "__main__":
  tf.app.run()

The ps task prints this:

I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> TF2:2222, 1 -> TF0:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222

While the workers printed somethings similar, and then some info:

I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job ps -> {0 -> TF1:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:200] Initialize GrpcChannelCache for job worker -> {0 -> TF2:2222, 1 -> localhost:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:221] Started server with target: grpc://localhost:2222
INFO:tensorflow:SyncReplicasV2: replicas_to_aggregate=2; total_num_replicas=2
[...]
I tensorflow/core/common_runtime/simple_placer.cc:841] Train/gradients/CrossEntropy/Mean_grad/Prod_1: (Prod)/job:worker/replica:0/task:1/cpu:0
: /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Sub_2/y: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat_1/axis: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat_1/values_0: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Slice_1/size: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Sub_1/y: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Rank_2: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat/axis: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/concat/values_0: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Slice/size: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Sub/y: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Rank_1: (Const): /job:worker/replica:0/task:1/cpu:0
CrossEntropy/Rank: (Const): /job:worker/replica:0/task:1/cpu:0
zeros_1: (Const): /job:worker/replica:0/task:1/cpu:0
GradientDescent/value: (Const): /job:ps/replica:0/task:0/cpu:0
Fill/dims: (Const): /job:ps/replica:0/task:0/cpu:0
zeros: (Const): /job:worker/replica:0/task:1/cpu:0
Input/LABELS: (Placeholder): /job:worker/replica:0/task:1/cpu:0
Input/X: (Placeholder): /job:worker/replica:0/task:1/cpu:0
init_all_tables: (NoOp): /job:ps/replica:0/task:0/cpu:0
group_deps/NoOp: (NoOp): /job:ps/replica:0/task:0/cpu:0
report_uninitialized_variables/boolean_mask/strided_slice_1: (StridedSlice): /job:ps/replica:0/task:0/cpu:0
report_uninitialized_variables/boolean_mask/strided_slice: (StridedSlice): /job:ps/replica:0/task:0/cpu:0
[...]
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Slice_1/size: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Sub_1/y: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Rank_2: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/concat/axis: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/concat/values_0: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Slice/size: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Sub/y: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Rank_1: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] CrossEntropy/Rank: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] zeros_1: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] GradientDescent/value: (Const)/job:ps/replica:0/task:0/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] Fill/dims: (Const)/job:ps/replica:0/task:0/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] zeros: (Const)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] Input/LABELS: (Placeholder)/job:worker/replica:0/task:1/cpu:0
I tensorflow/core/common_runtime/simple_placer.cc:841] Input/X: (Placeholder)/job:worker/replica:0/task:1/cpu:0

At this point not much else happens. I tried different configurations for SyncReplicasOptimizer, but nothing seems to work.

Any help would be very appreciated!

Edit: Commands used from the command line. For the ps server and workers respectively (different task_index for workers):

python filename.py --ps_hosts=server1:2222 --worker_hosts=server2:2222,server3:2222 --job_name=ps --task_index=0
python filename.py --ps_hosts=server1:2222 --worker_hosts=server2:2222,server3:2222 --job_name=worker --task_index=0

Solution

  • While looking at other synchronous distributed tensorflow examples, I found out some pieces of tensorflow that made the code work. Specifically (after train_step):

    if (FLAGS.task_index == 0): # is chief?
        # Initial token and chief queue runners required by the sync_replicas mode
        chief_queue_runner = opt.get_chief_queue_runner()
        init_tokens_op = opt.get_init_tokens_op()
    

    and (inside of the session before the loop):

    if (FLAGS.task_index == 0): # is chief?
        # Chief worker will start the chief queue runner and call the init op
        print("Starting chief queue runner and running init_tokens_op")
        sv.start_queue_runners(sess, [chief_queue_runner])
        sess.run(init_tokens_op)
    

    So, it wasn't enough to wrap the optimizer with SyncReplicaOptimizer, but also to create and use the queue_runner and init_tokens_op. I'm not sure why this worked, but I hope this helps someone else.