Tags: tensorflow, tensorflow2.0, tf.keras, tensorflow-federated

'Attempting to capture an EagerTensor without building a function' Error: While building Federated Averaging Process


I am getting an 'Attempting to capture an EagerTensor without building a function' error while trying to build my Federated Averaging process. I have tried all the v1/v2 compatibility remedies suggested in other similar Stack Overflow questions, e.g. tf.compat.v1.enable_eager_execution(), tf.compat.v1.disable_v2_behavior(), etc., but nothing worked. The relevant code extract is given below. My complete code, as a Python notebook, is at https://gist.github.com/aksingh2411/60796ee58c88e0c3f074c8909b17b5a1

# Making a TensorFlow model
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_federated as tff
from tensorflow import keras

# `encoder` and `preprocessed_example_dataset` are defined elsewhere in the notebook.

def create_keras_model():
  return tf.keras.models.Sequential([
      hub.KerasLayer(encoder, input_shape=[], dtype=tf.string, trainable=True),
      keras.layers.Dense(32, activation='relu'),
      keras.layers.Dense(16, activation='relu'),
      keras.layers.Dense(1, activation='sigmoid'),
  ])

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.BinaryCrossentropy(),
      metrics=[tf.keras.metrics.Accuracy()])

# Building the Federated Averaging process
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-68fa27e65b7e> in <module>()
  3     model_fn,
  4     client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
  -->5     server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

 /usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
263       except Exception as e:  # pylint:disable=broad-except
264         if hasattr(e, 'ag_error_metadata'):
--> 265           raise e.ag_error_metadata.to_exception(e)
266         else:
267           raise

RuntimeError: in user code:

/usr/local/lib/python3.6/dist-packages/tensorflow_hub/keras_layer.py:222 call  *
    result = f()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/saved_model/load.py:486 _call_attribute  **
    return instance.__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:580 __call__
    result = self._call(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py:618 _call
    results = self._stateful_fn(*args, **kwds)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:2420 __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:1665 _filtered_call
    self.captured_inputs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:1760 _call_flat
    flat_outputs = forward_function.call(ctx, args_with_tangents)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py:627 call
    executor_type=executor_type)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/functional_ops.py:1148 partitioned_call
    args = [ops.convert_to_tensor(x) for x in args]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/functional_ops.py:1148 <listcomp>
    args = [ops.convert_to_tensor(x) for x in args]
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py:1307 convert_to_tensor
    raise RuntimeError("Attempting to capture an EagerTensor without "

RuntimeError: Attempting to capture an EagerTensor without building a function.

Solution

  • It looks like tensors are being created outside model_fn() and later captured by it. The comment inside model_fn() is relevant here:

    # We _must_ create a new model here, and _not_ capture it from an external scope. TFF 
    # will call this within different graph contexts.
    

    TensorFlow doesn't allow referencing tensors that were created in a different graph (or inside a different tf.function), so everything that model_fn() references must be constructed inside model_fn() (or inside the create_keras_model() it calls).

    To find where the errant tensor is being created, it helps to examine the stack trace. The first user-code frame in the stack trace points to tensorflow_hub:

    /usr/local/lib/python3.6/dist-packages/tensorflow_hub/keras_layer.py:222 call  *
        result = f()
    

    The place in the source code that most obviously uses TF Hub is the first layer of the tf.keras.models.Sequential construction:

    def create_keras_model():
      return tf.keras.models.Sequential([
        hub.KerasLayer(encoder, input_shape=[],dtype=tf.string,trainable=True),
        …
    

    It seems this function may be "closing over" (capturing) the value of encoder, which in turn may hold tensors that were created in a different context. Is it possible to move the construction of encoder inside create_keras_model()?
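A minimal sketch of the mechanism, using plain TensorFlow 2.x with no TFF or TF Hub involved. It assumes that calling a function under a fresh tf.Graph is a fair stand-in for how TFF traces model_fn(); the names outer_weights, model_fn_capturing, model_fn_self_contained and call_in_fresh_graph are hypothetical. A tensor created eagerly at the top level should fail to be used inside that graph with the same kind of RuntimeError, while a tensor created inside the function works. A tf.function, by contrast, is "building a function", so it is allowed to capture the eager tensor.

```python
import tensorflow as tf

# Hypothetical stand-in for `encoder`: a tensor created eagerly at the top
# level of the notebook, i.e. outside any graph that TFF later builds.
outer_weights = tf.constant([1.0, 2.0, 3.0])


def model_fn_capturing():
    # Anti-pattern: closes over a tensor that lives in the eager context.
    return tf.reduce_sum(outer_weights)


def model_fn_self_contained():
    # Recommended pattern: create everything inside the function, so it is
    # built in whatever graph is current when the function is called.
    weights = tf.constant([1.0, 2.0, 3.0])
    return tf.reduce_sum(weights)


def call_in_fresh_graph(fn):
    """Call `fn` under a new tf.Graph, loosely mimicking how TFF traces model_fn."""
    graph = tf.Graph()
    with graph.as_default():
        result = fn()
    return graph, result


try:
    call_in_fresh_graph(model_fn_capturing)
    capture_error = None
except RuntimeError as e:
    capture_error = str(e)

graph, result = call_in_fresh_graph(model_fn_self_contained)

print(capture_error)          # a RuntimeError message about capturing an EagerTensor
print(result.graph is graph)  # True: the tensor belongs to the new graph


# A tf.function *is* building a function, so it may capture the eager tensor.
@tf.function
def sum_in_function():
    return tf.reduce_sum(outer_weights)


print(float(sum_in_function()))  # 6.0
```

Applied to the question, the analogous change would be to construct the TF Hub layer entirely inside create_keras_model(), for example by passing the module handle string to hub.KerasLayer there, instead of wrapping an encoder object created at the top of the notebook. This assumes encoder is an already-loaded object (for example the result of hub.load()); the notebook is not reproduced here, so that part is a guess.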