This is my model, and I have implemented it once in TensorFlow.
def create_compiled_keras_model():
inputs = Input(shape=(7, 20, 1))
l0_c = Conv2D(32, kernel_size=(7, 7), padding='valid', activation='relu')(inputs)
l1_c = Conv2D(32, kernel_size=(1, 5), padding='same', activation='relu')(l0_c)
l1_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')(l1_c)
l2_c = Conv2D(64, kernel_size=(1, 4), padding='same', activation='relu')(l1_p)
l2_p = AveragePooling2D(pool_size=(1, 2), strides=2, padding='same')
l3_c = Conv2D(2, kernel_size=(1, 1), padding='valid', activation='sigmoid')(l2_p)
predictions = Flatten()(l3_c)
predictions = tf.cast(predictions, dtype='float32')
model = Model(inputs=inputs, outputs=predictions)
opt = Adam(lr=0.0005)
print(model.summary())
def loss_fn(y_true, y_pred):
return tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_pred, y_true))
model.compile(optimizer=opt,
loss=loss_fn,
metrics=['accuracy'])
return model
I get this error in TensorFlow Federated.
Traceback (most recent call last):
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 270, in report
keras_metric = metric_type.from_config(metric_config)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 594, in from_config
return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'fn'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 203, in <module>
quantization_part = FedAvgQ.build_federated_averaging_process(model_fn)
File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/new_fedavg_keras.py", line 195, in build_federated_averaging_process
stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py", line 351, in build_model_delta_optimizer_process
dummy_model_for_metadata = model_utils.enhance(model_fn())
File "/Users/amir/Documents/CODE/Python/FL/fl_dataset_khudemon/fl.py", line 196, in model_fn
return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 216, in from_compiled_keras_model
return model_utils.enhance(_TrainableKerasModel(keras_model, dummy_tensors))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 491, in __init__
inner_model.loss_weights, inner_model.metrics)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 381, in __init__
federated_output, federated_local_outputs_type)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/api/computations.py", line 223, in federated_computation
return computation_wrapper_instances.federated_computation_wrapper(*args)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 410, in __call__
self._wrapper_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 78, in _federated_computation_wrapper_fn
suggested_name=name))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/federated_computation_utils.py", line 76, in zero_or_one_arg_fn_to_building_block
context_stack))
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
return lambda arg: _call(fn, parameter_type, arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
return fn(arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 377, in federated_output
type(metric), metric.get_config(), variables)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 260, in federated_aggregate_keras_metric
@tff.tf_computation(member_type)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 415, in <lambda>
return lambda fn: _wrap(fn, arg_type, self._wrapper_fn)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 103, in _wrap
concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper_instances.py", line 44, in _tf_wrapper_fn
target_fn, parameter_type, ctx_stack)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/tensorflow_serialization.py", line 278, in serialize_py_fn_as_tf_computation
result = target(*args)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 652, in <lambda>
return lambda arg: _call(fn, parameter_type, arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py", line 645, in _call
return fn(arg)
File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 278, in report
t=metric_type, c=metric_config, e=e))
TypeError: Caught exception trying to call `<class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.from_config()` with config {'name': 'accuracy', 'dtype': 'float32'}. Confirm that <class 'tensorflow.python.keras.metrics.MeanMetricWrapper'>.__init__() has an argument for each member of the config.
Exception: __init__() missing 1 required positional argument: 'fn'
My dataset's label is a kind of two labels [0. 1.]
and I used binary_crossentropy
for loss function. But the accuracy gets back the error. I am sure it is related to multiple labels. The loss calculated without any problem when I remove the accuracy. Any help would be greatly appreciated.
TensorFlow Federated unfortunately isn't able to understand Keras models that have been compiled with string arguments. TFF requires the compile()
call on the model be given instances of tf.keras.losses.Loss
or tf.keras.metrics.Metric
. It should be possible to change the last part of the code in question to:
model.compile(optimizer=opt,
loss=tf.keras.losses.BinaryCrossentropy(),
metrics=[tf.keras.metrics.Accuracy()])
Note that there shouldn't be a need to define a custom loss function, Keras provides a canned binary crossentropy.