Tags: tensorflow, graph, dependencies, execution, order-of-execution

Unexpected (random) execution order using tf.control_dependencies (tensorflow v1)


When I run the following code (tf v1.12.0), I get either 6.0 (x->mul->ident), 7.0 (x->mul->add->ident), or 9.0 (x->add->mul->ident).

Could someone please explain why the order of execution of the ops is not controlled by the tf.control_dependencies? I would think that at least add_op would be executed before anything within the control context is even considered.

import tensorflow as tf  # TensorFlow 1.12

tf.reset_default_graph()

x = tf.Variable(2.0)
add_op = tf.assign_add(x, 1)   # x <- x + 1
mul_op = tf.assign(x, 3 * x)   # x <- 3 * x

# out_op gets add_op as a control dependency
with tf.control_dependencies([add_op]):
    out_op = tf.identity(mul_op)

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run([out_op]))

Thanks!


Solution

  • This is because mul_op does not depend on add_op. Rather, out_op depends on both mul_op (as an explicit input) and add_op (as a control dependency), so add_op is only guaranteed to run before out_op, not before mul_op. In TensorFlow 1.x (and in TensorFlow 2.x inside a tf.Graph context), the order in which the operations are written in Python does not determine the order in which the TensorFlow runtime executes them.

    To force deterministic behavior for the example above, there are a few options.

    1. Construct the mul_op inside a tf.control_dependencies context using the add_op:
    add_op = tf.assign_add(x, 1)
    with tf.control_dependencies([add_op]):
      mul_op = tf.assign(x, 3 * x)
    
    2. Have mul_op take the output of the addition (add_op) as an input.
    add_op = tf.assign_add(x, 1)
    mul_op = tf.assign(x, 3 * add_op)
    
    3. Remove the control dependency from the identity op and run add_op in its own sess.run() call before running out_op.
    x=tf.Variable(2.0)
    add_op = tf.assign_add(x, 1)
    mul_op = tf.assign(x, 3*x)
    out_op = tf.identity(mul_op)
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(add_op)
        print(sess.run(out_op))
    

    These always return 9.0.
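
To see why the original graph can produce three different values while the fixes produce one, here is a toy model in plain Python. It is not TensorFlow and does not reflect how the runtime actually schedules ops. It assumes the multiplication `3 * x` reads `x` as a separate step from the assignment, and that the identity reads whatever `x` holds when it runs (the behavior the question observed on TF 1.12). The step names `read_mul`, `assign`, `add`, and `identity` are made up for this sketch. It enumerates every execution order that respects a set of "a before b" constraints and collects the results.

```python
from itertools import permutations

def run(order):
    """Execute the four toy steps in the given order and return out_op's value."""
    x = 2.0          # initial value of the variable
    product = None   # result of 3 * x, computed before the assign writes it
    out = None
    for step in order:
        if step == "read_mul":    # compute 3 * x from the current value of x
            product = 3 * x
        elif step == "assign":    # mul_op: write the product back into x
            x = product
        elif step == "add":       # add_op: x <- x + 1
            x = x + 1
        elif step == "identity":  # out_op: read the current value of x
            out = x
    return out

def possible_results(constraints):
    """Return the set of results over every order that respects `constraints`.

    Each constraint (a, b) means step a must run before step b.
    """
    steps = ["read_mul", "assign", "add", "identity"]
    results = set()
    for order in permutations(steps):
        if all(order.index(a) < order.index(b) for a, b in constraints):
            results.add(run(order))
    return results

# Dependencies in the question's graph: two data edges plus one control edge.
original = [("read_mul", "assign"), ("assign", "identity"), ("add", "identity")]
# Option 1 adds a control edge so the multiplication waits for add_op.
fixed = original + [("add", "read_mul")]

print(sorted(possible_results(original)))  # [6.0, 7.0, 9.0]
print(sorted(possible_results(fixed)))     # [9.0]
```

In this model, add_op is unordered relative to the multiplication and the assignment in the original graph, which is what leaves three valid schedules. Adding the single extra edge leaves only one.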

    To dig really deep and see what the dependencies in the graph are, you could try:

    tf.get_default_graph().as_graph_def()
    

    and see what the inputs to each node in the graph are.
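
As a rough sketch of that inspection, the snippet below builds the question's graph and prints each node with its inputs. In a GraphDef, control dependencies appear as input names prefixed with `^`. The helper `control_inputs` is a hypothetical name introduced here, not part of TensorFlow. The import assumes a TensorFlow version that ships `tensorflow.compat.v1`; on 1.12 itself, `import tensorflow as tf` should be used instead. Node names and op types vary between versions, so no expected output is shown.

```python
# Assumption: a TensorFlow version that provides tensorflow.compat.v1.
# On TF 1.12, replace this line with `import tensorflow as tf`.
import tensorflow.compat.v1 as tf

def control_inputs(graph_def):
    """Map each node name to the names of its control inputs ('^'-prefixed)."""
    found = {}
    for node in graph_def.node:
        deps = [name[1:] for name in node.input if name.startswith("^")]
        if deps:
            found[node.name] = deps
    return found

# Build the question's graph in a fresh tf.Graph so nothing else leaks in.
graph = tf.Graph()
with graph.as_default():
    x = tf.Variable(2.0)
    add_op = tf.assign_add(x, 1)
    mul_op = tf.assign(x, 3 * x)
    with tf.control_dependencies([add_op]):
        out_op = tf.identity(mul_op)

graph_def = graph.as_graph_def()

# Print every node with its op type and its inputs (data and control).
for node in graph_def.node:
    print(node.name, node.op, list(node.input))

# Only the control edges, which is where add_op should show up.
print(control_inputs(graph_def))
```

If the node that performs the multiplication or the assignment lists no input, data or control, that leads back to the add op, then nothing in the graph orders it after add_op.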