
How do I use TensorFlow's merge and switch functions?


The merge and switch ops may not be intended for general users, so I searched the source code and found this description of merge:

Returns the value of an available element of inputs.

What does "available" mean here? Is it something returned by switch? Here is a demo:

import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)
from tensorflow.python.ops import control_flow_ops

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
y = control_flow_ops.merge([x_0, x_1, x_2, x_3])
with tf.Session() as sess:
    print(sess.run(y))

Solution

  • switch

    Let's start by examining the control_flow_ops.switch function:

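    import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)
    from tensorflow.python.ops import control_flow_ops

    x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
    x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
    with tf.Session() as sess:
      print(sess.run(x_0))    # prints 2
      print(sess.run(x_3))    # prints 7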
    

    control_flow_ops.switch(data, pred) returns a pair of tensors (output_false, output_true), but only one of them receives a value: data is forwarded to output_false when pred is False and to output_true when pred is True. In the example above, that is x_0 = 2 from the first switch and x_3 = 7 from the second one. Trying to evaluate x_1 or x_2 results in a "Retval does not have value" error:

      sess.run(x_1)  # FAILS!
      sess.run(x_2)  # FAILS!
    

    In other words, x_0 and x_3 are available, while x_1 and x_2 are not.

  • merge

    control_flow_ops.merge performs the inverse operation: given a list of tensors, it selects the available one. More precisely, it returns a named tuple (output, value_index) holding the value of the available tensor and its index in the input list. According to the current docs, the input should contain exactly one available tensor, which means your demo (with two available inputs, x_0 and x_3) is, strictly speaking, unsupported and its behavior is undefined. Here's an example:

    merge = control_flow_ops.merge
    with tf.Session() as sess:
      print(sess.run(merge([x_0, x_1])))       # Merge(output=2, value_index=0)
      print(sess.run(merge([x_1, x_0])))       # Merge(output=2, value_index=1)
      print(sess.run(merge([x_2, x_3])))       # Merge(output=7, value_index=1)
      print(sess.run(merge([x_3, x_2])))       # Merge(output=7, value_index=0)
      print(sess.run(merge([x_0, x_1, x_2])))  # Merge(output=2, value_index=0)
      print(sess.run(merge([x_1, x_2, x_3])))  # Merge(output=7, value_index=2)
    

    Both functions are handy for controlling the flow of computation. For example, the gradient of merge is computed with a switch op, and the gradient of switch with a merge op (see tensorflow/python/ops/control_flow_grad.py in the TensorFlow source).
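To make "available" concrete without needing a TensorFlow install, here is a toy, pure-Python model of the semantics described above. It is a sketch, not TensorFlow's API: the names DEAD, switch_model, merge_model and cond_model are hypothetical, and a real graph marks an untaken output as dead rather than storing a sentinel object. Following the "exactly one available input" contract quoted above, merge_model raises an error instead of guessing when several inputs are available. The cond_model function shows how a cond-style branch can be assembled from nothing but switch and merge.

```python
from typing import Any, List, NamedTuple, Tuple


class _Dead:
    """Hypothetical sentinel standing in for a tensor that has no value."""

    def __repr__(self) -> str:
        return "DEAD"


DEAD = _Dead()


class Merge(NamedTuple):
    output: Any
    value_index: int


def switch_model(data: Any, pred: bool) -> Tuple[Any, Any]:
    """Toy switch: forward data to output_true if pred, else to output_false."""
    if data is DEAD:
        # In this model a dead input makes both outputs dead.
        return DEAD, DEAD
    return (DEAD, data) if pred else (data, DEAD)


def merge_model(inputs: List[Any]) -> Merge:
    """Toy merge: return the single available input and its index."""
    available = [i for i, v in enumerate(inputs) if v is not DEAD]
    if len(available) != 1:
        # Assumption: enforce the "exactly one available input" contract.
        raise ValueError(f"expected exactly one available input, got {len(available)}")
    i = available[0]
    return Merge(output=inputs[i], value_index=i)


def cond_model(pred: bool, true_fn, false_fn, x: Any) -> Any:
    """A cond-style branch built only from switch_model and merge_model."""
    x_false, x_true = switch_model(x, pred)
    # Only the branch whose input is available actually runs.
    out_true = true_fn(x_true) if x_true is not DEAD else DEAD
    out_false = false_fn(x_false) if x_false is not DEAD else DEAD
    return merge_model([out_false, out_true]).output


x_0, x_1 = switch_model(2, False)
x_2, x_3 = switch_model(7, True)
print(merge_model([x_0, x_1]))       # Merge(output=2, value_index=0)
print(merge_model([x_1, x_2, x_3]))  # Merge(output=7, value_index=2)
print(cond_model(True, lambda v: v * 10, lambda v: -v, 3))  # 30
```

In this model, the demo from the question, merge_model([x_0, x_1, x_2, x_3]), raises a ValueError because two inputs are available. Raising is a deliberate choice here; the real op makes no such promise.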