Search code examples
pythonkerastensorflow-federated

get 'TypeError: Caught exception' for using 'accuracy' in Tensorflow Federated


This is my model, and I have implemented it once in TensorFlow.

def create_compiled_keras_model():

    inputs = Input(shape=(7, 20, 1))
    l0_c = Conv2D(32, kernel_size=(7, 7), padding='valid', activation='relu')(inputs)
    l1_c = Conv2D(32, kernel_size=(1, 5), padding='same', activation='relu')(l0_c)
    l1_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')(l1_c)
    l2_c = Conv2D(64, kernel_size=(1, 4), padding='same', activation='relu')(l1_p)
    l2_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')
    l3_c = Conv2D(2, kernel_size=(1, 1), padding='valid', activation='sigmoid')(l2_p)
    predictions = Flatten()(l3_c)
    predictions = tf.cast(predictions, dtype='float32')
    model = Model(inputs=inputs, outputs=predictions)
    opt = Adam(lr=0.0005)
    print(model.summary())
    def loss_fn(y_true, y_pred):
        return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_pred, y_true))
    model.compile(optimizer=opt,
                  loss=loss_fn,
                  metrics=['accuracy'])
    return model

I get this error in TensorFlow Federated.

Traceback (most recent call last):
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 270, in report
    keras_metric = metric_type.from_config(metric_config)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 594, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'fn'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 203, in <module>
    quantization_part = FedAvgQ.build_federated_averaging_process(model_fn)
  File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/new_fedavg_keras.py", line 195, in build_federated_averaging_process
    stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py", line 351, in build_model_delta_optimizer_process
    dummy_model_for_metadata = model_utils.enhance(model_fn())
  File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 196, in model_fn
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 216, in from_compiled_keras_model
    return model_utils.enhance(_TrainableKerasModel(keras_model, dummy_tensors))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 491, in __init__
    inner_model.loss_weights, inner_model.metrics)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 381, in __init__
    federated_output, federated_local_outputs_type)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/api/computations.py", line 223, in federated_computation
    return computation_wrapper_instances.federated_computation_wrapper(*args)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 410, in __call__
    self._wrapper_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 78, in _federated_computation_wrapper_fn
    suggested_name=name))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/federated_computation_utils.py", line 76, in zero_or_one_arg_fn_to_building_block
    context_stack))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
    return lambda arg: _call(fn, parameter_type, arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
    return fn(arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 377, in federated_output
    type(metric), metric.get_config(), variables)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 260, in federated_aggregate_keras_metric
    @tff.tf_computation(member_type)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 415, in <lambda>
    return lambda fn: _wrap(fn, arg_type, self._wrapper_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
    concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 44, in _tf_wrapper_fn
    target_fn, parameter_type, ctx_stack)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/tensorflow_serialization.py", line 278, in serialize_py_fn_as_tf_computation
    result = target(*args)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
    return lambda arg: _call(fn, parameter_type, arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
    return fn(arg)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 278, in report
    t=metric_type, c=metric_config, e=e))
TypeError: Caught exception trying to call `<class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.from_config()` with config {'name': 'accuracy', 'dtype': 'float32'}. Confirm that <class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.__init__() has an argument for each member of the config.
Exception: __init__() missing 1 required positional argument: 'fn'

My dataset's label is a kind of two labels [0. 1.] and I used binary_crossentropy for loss function. But the accuracy gets back the error. I am sure it is related to multiple labels. The loss calculated without any problem when I remove the accuracy. Any help would be greatly appreciated.


Solution

  • TensorFlow Federated unfortunately isn't able to understand Keras models that have been compiled with string arguments. TFF requires the compile() call on the model be given instances of tf.keras.losses.Loss or tf.keras.metrics.Metric. It should be possible to change the last part of the code in question to:

    model.compile(optimizer=opt,
                  loss=tf.keras.losses.BinaryCrossentropy(),
                  metrics=[tf.keras.metrics.Accuracy()])
    

    Note that there shouldn't be a need to define a custom loss function, Keras provides a canned binary crossentropy.