
How can I determine whether an intermediate result has any data?


How can I implement "if a tensor contains any items, compute their average; otherwise, assign it a fixed value"? Take tf.gather_nd() as an example, choosing some rows from source_tensor, which has shape (?, 2):

result = tf.gather_nd(source_tensor, indices)

This should gather the rows of source_tensor selected by indices. But if indices is an empty list [], tf.gather_nd does not raise an error: the program continues and result simply contains nothing.

So is there a way to determine whether result is empty (that is, whether it has no data) when building the TensorFlow computational graph? If so, I want to assign it a constant value manually.
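While the graph is being built, the row count of result is only known as None (the ? in (?, 2)), so emptiness cannot be decided statically; it has to be checked by an op such as tf.size that runs together with the graph. A minimal sketch of that distinction, assuming TensorFlow 2.x, where a tf.function with an input_signature of shape [None, 1] stands in for the graph being built. The function name element_count, the list traced_shapes, and the values in source_tensor are hypothetical.

```python
import tensorflow as tf

# Hypothetical source_tensor with shape (3, 2); the values are illustrative.
source_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
traced_shapes = []  # filled in while the graph is being traced

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.int32)])
def element_count(indices):
    result = tf.gather_nd(source_tensor, indices)
    # Python code here runs at trace time, when the row count is still unknown.
    traced_shapes.append(result.shape.as_list())
    # tf.size is a graph op, so it is evaluated on every call with real data.
    return tf.size(result)

print(int(element_count(tf.constant([[0], [1]]))))           # 4
print(int(element_count(tf.zeros([0, 1], dtype=tf.int32))))  # 0
print(traced_shapes[0])                                       # [None, 2]
```

The static shape can only say "unknown number of rows"; the runtime size is what distinguishes the empty case.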

This matters because my next step is

tf.reduce_mean(result)

If result has no data, tf.reduce_mean(result) returns nan.
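A minimal sketch that reproduces the problem, assuming TensorFlow 2.x in eager mode. The values in source_tensor are made up, and indices is written as an int32 tensor of shape (0, 1) (zero index rows, each selecting one row of source_tensor) rather than a bare [] list; that layout is an assumption.

```python
import tensorflow as tf

# Hypothetical source_tensor with shape (3, 2); the values are illustrative.
source_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Assumed empty index set: zero rows of single-element indices.
indices = tf.zeros([0, 1], dtype=tf.int32)

result = tf.gather_nd(source_tensor, indices)
print(result.shape.as_list())         # [0, 2]: no error, just no data
print(float(tf.reduce_mean(result)))  # nan: the mean of zero elements
```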


Solution

  • You should be able to do this via tf.cond, which executes one of two branches depending on a boolean condition. I haven't tested the code below, so please report whether it works.

    mean = tf.cond(tf.size(result) > 0, lambda: tf.reduce_mean(result),
                   lambda: tf.cast(some_constant, result.dtype))
    

    The idea is to check whether result contains any items via tf.size, which returns 0 if result is empty. Because tf.cond expects a scalar boolean predicate, the integer returned by tf.size should be converted explicitly, e.g. with tf.size(result) > 0 or tf.cast(tf.size(result), tf.bool).
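A self-contained sketch of the tf.cond approach, assuming TensorFlow 2.x, with a tf.function playing the role of the graph. The values in source_tensor, the fallback value some_constant = -1.0, and the function name safe_mean are hypothetical; substitute your own.

```python
import tensorflow as tf

# Hypothetical stand-ins for the question's source_tensor and fallback value.
source_tensor = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
some_constant = -1.0  # assumed fallback; pick whatever suits your model

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 1], dtype=tf.int32)])
def safe_mean(indices):
    result = tf.gather_nd(source_tensor, indices)
    # tf.cond needs a scalar bool predicate, so compare the element count to 0.
    # Only the selected branch runs, so reduce_mean never sees an empty tensor.
    return tf.cond(
        tf.size(result) > 0,
        lambda: tf.reduce_mean(result),
        lambda: tf.cast(some_constant, result.dtype),
    )

print(float(safe_mean(tf.constant([[0], [2]]))))           # 3.5
print(float(safe_mean(tf.zeros([0, 1], dtype=tf.int32))))  # -1.0
```

If a fallback of 0 is acceptable, tf.math.divide_no_nan(tf.reduce_sum(result), tf.cast(tf.size(result), result.dtype)) is a different technique that avoids tf.cond altogether, since it returns 0 when the denominator is 0.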