
Example of tf.Estimator with model parallel execution

I am currently experimenting with distributed TensorFlow. I am using the tf.estimator.Estimator class (with a custom model function) together with tf.contrib.learn.Experiment, and I managed to get data-parallel execution working.

However, I would now like to try model-parallel execution. I was not able to find any example of that, except the question "Implementation of model parallelism in tensorflow". But I am not sure how to implement this using tf.estimator (e.g., how should the input functions be handled?).

Does anybody have any experience with it or can provide a working example?


  • First, you should stop using tf.contrib.learn.Estimator in favor of tf.estimator.Estimator: contrib is an experimental module, and the contrib versions of classes that have graduated to the core API (such as Estimator) are deprecated.

    Now, back to your main question: you can write a model function that places different parts of the network on different devices and pass it via the model_fn parameter of tf.estimator.Estimator.__init__.

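    import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.losses, tf.train)

    X_FEATURE = 'x'  # key of the input feature in the features dict

    def my_model(features, labels, mode):
      # Note: this model_fn expects labels, so it supports TRAIN/EVAL only.
      net = features[X_FEATURE]
      is_training = (mode == tf.estimator.ModeKeys.TRAIN)
      # Hidden layers (dense + dropout) are placed on /device:GPU:1.
      with tf.device('/device:GPU:1'):
        for units in [10, 20, 10]:
          net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
          # Dropout is only active when training=True.
          net = tf.layers.dropout(net, rate=0.1, training=is_training)
      # Output layer, one-hot labels and loss are placed on /device:GPU:2.
      with tf.device('/device:GPU:2'):
        logits = tf.layers.dense(net, 3, activation=None)
        onehot_labels = tf.one_hot(labels, 3, 1, 0)
        loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels,
                                               logits=logits)
      optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
      # Colocate each gradient op with its forward op so backprop is split too.
      train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step(),
                                    colocate_gradients_with_ops=True)
      return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    classifier = tf.estimator.Estimator(model_fn=my_model)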

    The model above places six layers (three dense layers, each followed by dropout) on /device:GPU:1, and the output layer together with the one-hot encoding and the loss on /device:GPU:2. The my_model function must return an EstimatorSpec instance. A complete working example can be found in the TensorFlow examples.
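If you want to spread more layers over more GPUs without hard-coding each tf.device block, one option is to compute the split up front. The sketch below uses a hypothetical helper, partition_layers (not part of TensorFlow), which assumes a simple contiguous split: the list of layer widths is cut into one stage per device, with stage sizes differing by at most one layer.

```python
def partition_layers(layer_units, devices):
    """Split layer widths into contiguous stages, one per device string.

    Hypothetical helper: stages differ in length by at most one layer,
    and earlier devices receive the extra layers when the split is uneven.
    """
    if not devices:
        raise ValueError("need at least one device")
    n, k = len(layer_units), len(devices)
    base, extra = divmod(n, k)
    plan, start = [], 0
    for i, device in enumerate(devices):
        # The first `extra` devices take one additional layer each.
        size = base + (1 if i < extra else 0)
        plan.append((device, layer_units[start:start + size]))
        start += size
    return plan


if __name__ == "__main__":
    for device, units in partition_layers([10, 20, 10, 3],
                                          ["/device:GPU:0", "/device:GPU:1"]):
        print(device, units)
```

Inside my_model you would then loop over the returned (device, units) pairs and build each stage's dense layers inside `with tf.device(device):`. Splitting by layer count ignores differences in per-layer compute and memory, so in practice you may want to weight the stages by layer size instead.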