How to dereference _ref tensor type in TensorFlow?


How can I convert a reference-type tensor to a value-type tensor?

The only way I have found is to add zero to the tensor. Is there a more convenient way?

In the code below, assign is a tensor of reference type. How can I get rid of the _ref?

import tensorflow as tf

counter = tf.Variable(0, name="counter")

zero = tf.constant(0)
one = tf.constant(1)

new_counter = tf.add(counter, one)
assign = tf.assign(counter, new_counter) # dtype=int32_ref
result = tf.add(assign, zero) # dtype=int32
result2 = tf.convert_to_tensor(assign) # dtype=int32_ref
# result3 = assign.value() # has no attribute value

Solution

  • In general, you should be able to use a tf.foo_ref-type tensor anywhere a tf.foo-type tensor is expected. TensorFlow ops will implicitly dereference their input arguments (unless a reference tensor is explicitly expected, e.g. in tf.assign()).

    The simplest way to dereference a tensor is to use tf.identity(), as follows:

    counter = tf.Variable(0)
    assert counter.dtype == tf.int32_ref
    
    counter_val = tf.identity(counter)
    assert counter_val.dtype == tf.int32
    

    This answers your question, but it can have surprising semantics, because tf.identity() does not copy the underlying buffer. Therefore, counter and counter_val in the above example share the same buffer, and a modification to counter will be reflected in counter_val:

    counter = tf.Variable(0)
    counter_val = tf.identity(counter)  # Take alias before the `assign_add` happens.
    counter_update = counter.assign_add(1)
    
    with tf.control_dependencies([counter_update]):
      # Force a copy after the `assign_add` happens.
      result = counter_val + 0
    
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    
    print(sess.run(result))  # ==> 1  (result has effect of `assign_add`)
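
The alias-versus-copy behaviour described above can be illustrated with a rough plain-Python analogy. This is only a sketch and does not use TensorFlow: a bytearray stands in for the variable's buffer, a memoryview stands in for the aliasing result of tf.identity(), and bytes() stands in for a forced copy. The names buf, alias and snapshot are hypothetical and chosen for illustration.

```python
# Analogy only: plain Python buffers, not TensorFlow tensors.
buf = bytearray([0])       # stands in for the variable's underlying buffer
alias = memoryview(buf)    # shares buf's memory (no copy is made)
snapshot = bytes(buf)      # independent copy taken before the update

buf[0] += 1                # in-place update, analogous to assign_add(1)

alias_value = alias[0]         # reads the shared, updated buffer
snapshot_value = snapshot[0]   # still holds the value from before the update
print(alias_value, snapshot_value)
```

The alias observes the in-place update while the earlier copy does not, which mirrors why result in the TensorFlow example reflects the assign_add even though counter_val was created before it.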