
tensorflow: check if a scalar boolean tensor is True


I want to control the execution of a function using a placeholder, but I keep getting the error "Using a tf.Tensor as a Python bool is not allowed." Here is the code that produces it:

    import tensorflow as tf

    def foo(c):
      if c:
        print('This is true')
        # heavy code here
        return 10
      else:
        print('This is false')
        # different code here
        return 0

    a = tf.placeholder(tf.bool)  # placeholder for a single boolean value
    b = foo(a)
    sess = tf.InteractiveSession()
    res = sess.run(b, feed_dict={a: True})
    sess.close()

I also tried changing if c to if c is not None, without luck. How can I control foo by switching the placeholder a on and off?
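Both attempts can be reproduced without TensorFlow. The sketch below is a toy model, not TensorFlow code: SymbolicBool is a hypothetical stand-in for a graph-mode boolean tensor, which has no value while the graph is being built. Like a graph-mode tf.Tensor, it refuses conversion to a Python bool, so if c fails, while if c is not None is only an identity check that never looks at the value and therefore always picks the first branch.

```python
class SymbolicBool:
    """Hypothetical stand-in for a graph-mode tf.Tensor of dtype bool.

    It has no value at build time; a value only exists when the graph runs.
    """

    def __init__(self, name):
        self.name = name

    def __bool__(self):
        # Mimics graph-mode tf.Tensor, which rejects implicit bool conversion.
        raise TypeError("Using a symbolic value as a Python bool is not allowed.")


def foo(c):
    if c:  # calls c.__bool__() while the graph is being built
        return 10
    else:
        return 0


def foo_not_none(c):
    if c is not None:  # identity check only; never calls __bool__
        return 10
    else:
        return 0


a = SymbolicBool("a")
try:
    foo(a)
except TypeError as err:
    print("foo:", err)

print("foo_not_none:", foo_not_none(a))  # 10, whatever value a is fed later
```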

Update: as @nessuno and @nemo pointed out, we must use tf.cond instead of if...else. The answer to my question is to redesign my function like this, with func1 and func2 building the ops of the two branches:

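    import tensorflow as tf

    def func1():
      return tf.constant(10)  # heavy code here

    def func2():
      return tf.constant(0)  # different code here

    def foo(c):
      # tf.cond picks the branch when the graph runs, from the value fed for c
      return tf.cond(c, func1, func2)

    a = tf.placeholder(tf.bool)  # placeholder for a single boolean value
    b = foo(a)
    sess = tf.InteractiveSession()
    res = sess.run(b, feed_dict={a: True})
    sess.close()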
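Why tf.cond works where if does not can be seen in a toy dataflow model. The sketch below is plain Python and hypothetical: placeholder, constant, cond and run are illustrative helpers, not TensorFlow functions. As in TensorFlow 1.x graph mode, cond calls both branch functions once while the graph is built, and only the run step, given a feed, decides which branch's result is used.

```python
# Toy dataflow model (hypothetical helpers, not TensorFlow's API).
# A "node" is a function that maps a feed dict to a value.

build_log = []  # records which branch builders ran at graph-build time


def placeholder(name):
    return lambda feed: feed[name]


def constant(value):
    return lambda feed: value


def cond(pred, true_fn, false_fn):
    # Both branches are built now, before any value is fed.
    true_node = true_fn()
    false_node = false_fn()
    # The choice between them is deferred until run time.
    return lambda feed: true_node(feed) if pred(feed) else false_node(feed)


def run(node, feed_dict):
    return node(feed_dict)


def func1():
    build_log.append("func1")
    return constant(10)


def func2():
    build_log.append("func2")
    return constant(0)


a = placeholder("a")
b = cond(a, func1, func2)    # builds the graph once
print(build_log)             # ['func1', 'func2']
print(run(b, {"a": True}))   # 10
print(run(b, {"a": False}))  # 0
```

Because both branch functions run once at build time, a Python print placed inside func1 or func2 would fire while the graph is built, not each time the graph runs.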

Solution

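  • You have to use tf.cond to define a conditional operation within the graph, so that the flow of the tensors is decided when the graph runs. A Python if is evaluated only once, while the graph is being built, when a is still a placeholder with no value; that is why the tensor cannot be used as a Python bool. Testing if c is not None avoids the error but does not help: the placeholder object is never None, so the first branch is always the one built.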

    import tensorflow as tf
    
    a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
    b = tf.cond(a, lambda: tf.constant(10), lambda: tf.constant(0))  # a is already a scalar bool, so it can be the predicate
    sess = tf.InteractiveSession()
    res = sess.run(b, feed_dict = {a: True})
    sess.close()
    print(res)
    

    10
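The code above uses the TensorFlow 1.x graph API; tf.placeholder and tf.InteractiveSession are only reachable through tf.compat.v1 in TensorFlow 2.x. A minimal sketch of the same behavior in TensorFlow 2.x, assuming eager execution and tf.function: inside a tf.function, AutoGraph converts a Python if whose condition is a tensor into a tf.cond, so the question's original if/else structure works once the function is decorated. The name foo_explicit is hypothetical and shows the equivalent explicit tf.cond form.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)


@tf.function
def foo(c):
    # Inside tf.function, AutoGraph rewrites this `if` on a tensor into tf.cond,
    # so the branch is selected when the traced graph runs.
    if c:
        return tf.constant(10)  # heavy code here
    else:
        return tf.constant(0)  # different code here


@tf.function
def foo_explicit(c):
    # The same conditional, written with tf.cond directly instead of via AutoGraph.
    return tf.cond(c, lambda: tf.constant(10), lambda: tf.constant(0))


print(foo(tf.constant(True)).numpy())           # 10
print(foo(tf.constant(False)).numpy())          # 0
print(foo_explicit(tf.constant(True)).numpy())  # 10
```

Passing tensors rather than Python bools matters here: with a Python True or False, tf.function would retrace per value and keep a plain Python if, instead of building a single graph containing a tf.cond.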