Search code examples
python · tensorflow · tensorflow-datasets

Using tf.while_loop in tf.data to modify the content of the array causes an error


I want to use a for loop to manipulate an array inside tf.data. With the current tf.while_loop API, the loop variables passed in must match the structure returned by the body, so I created an array new_data in advance and then used tf.while_loop to modify the contents of the array sequentially, but the result is often wrong. Did I do something wrong?

code:

import tensorflow as tf
import numpy as np

def body(index, new_data):
    # Loop body for tf.while_loop: attempts to write 1 into new_data[index, 0].
    # NOTE(review): wrapping the slice in a tf.Variable is what raises the
    # TypeError quoted in the traceback below — the variable's initializer is
    # run outside the while_loop's function-building context, leaking the
    # graph tensor `new_data` (see "while/Placeholder:0" in the error).
    # NOTE(review): both lines target [index, 0]; the second was presumably
    # meant to be [index, 1] — confirm.
    tf.Variable(lambda:new_data[index, 0]).assign(1)  #An error occurred
    tf.Variable(lambda:new_data[index, 0]).assign(1)
    return tf.add(index,1), new_data
    
def main(data):
    # Mapped over every dataset element: builds a fresh (200, 2) zero tensor
    # and runs a 200-iteration while_loop over it. The loop result `r` is
    # discarded and the input element is returned unchanged — the loop exists
    # only to reproduce the error.
    new_data = tf.zeros((200,2), dtype=tf.float64)
    index = tf.constant(0)
    condition = lambda index, new_data: tf.less(index, 200)
    r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    return data
    
# Reproduction pipeline: map main() over 0..99, batch, prefetch, then iterate
# to force execution (iterating is what surfaces the error below).
trainDS = tf.data.Dataset.from_tensor_slices((np.arange(100)))
trainDS = (
        trainDS
        .map(main, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(10, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE))

for i in trainDS:
    i  # bare expression — just forces evaluation of each batch

error:

    test2.py:13 main  *
        r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:605 new_func  **
        return func(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:200 while_loop
        add_control_dependencies=add_control_dependencies)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:990 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:178 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    test2.py:5 body
        tf.Variable(lambda:new_data[index, 0]).assign(1)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:237 <lambda>
        previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py:2667 default_variable_creator_v2
        shape=shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1585 __init__
        distribute_strategy=distribute_strategy)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1712 _init_from_args
        initial_value = initial_value()
    test2.py:5 <lambda>
        tf.Variable(lambda:new_data[index, 0]).assign(1)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:1011 _slice_helper
        end.append(s + 1)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper
        raise e
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper
        return func(x, y, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1486 _add_dispatch
        return gen_math_ops.add_v2(x, y, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:477 add_v2
        x, y, name=name, ctx=_ctx)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:501 add_v2_eager_fallback
        ctx=ctx, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
        raise e
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
        inputs, attrs, num_outputs)

    TypeError: An op outside of the function building code is being passed
    a "Graph" tensor. It is possible to have Graph tensors
    leak out of the function building context by including a
    tf.init_scope in your function building code.
    For example, the following function will fail:
      @tf.function
      def has_init_scope():
        my_constant = tf.constant(1.)
        with tf.init_scope():
          added = my_constant * 2
    The graph tensor has name: while/Placeholder:0

environment: Python 3.6, tensorflow-gpu 2.4.0, Ubuntu 20.04


Solution

  • You could try using tf.tensor_scatter_nd_update for your use case, but note that main is currently called once per dataset element, so new_data is reinitialised on every call.

    import tensorflow as tf
    import numpy as np
    
    def body(index, new_data):
        """while_loop body: writes 1.0 into new_data[index, 0].

        Returns (index + 1, updated tensor) so the loop-variable structure
        matches what tf.while_loop requires.
        """
        # The original answer passed the same coordinate twice
        # ([[index, 0], [index, 0]]); duplicate scatter indices are redundant
        # here (both write 1.0 to the same slot), so a single update produces
        # an identical result.
        # NOTE(review): the question's code may have intended [index, 1] for
        # the second write — confirm before relying on column 1 staying 0.
        new_data = tf.tensor_scatter_nd_update(new_data, [[index, 0]], [1.0])
        return tf.add(index, 1), new_data
        
    def main(data):
        """Dataset map function: runs the 200-step scatter loop, then
        returns the input element unchanged.

        The while_loop result is deliberately discarded; the loop exists only
        to demonstrate that tensor_scatter_nd_update works inside tf.data.
        """
        start = tf.constant(0)
        buffer = tf.zeros((200, 2), dtype=tf.float64)

        def keep_going(i, _):
            return tf.less(i, 200)

        tf.while_loop(keep_going, body, loop_vars=(start, buffer))
        return data
        
    # Build the input pipeline step by step rather than as one chained
    # expression: map the loop over 0..99, batch, prefetch, then iterate to
    # force execution and print each batch.
    trainDS = tf.data.Dataset.from_tensor_slices(np.arange(100))
    trainDS = trainDS.map(main, num_parallel_calls=tf.data.AUTOTUNE)
    trainDS = trainDS.batch(10, drop_remainder=True)
    trainDS = trainDS.prefetch(tf.data.AUTOTUNE)

    for batch in trainDS:
        print(batch)