An error is raised while training a federated model that uses hub.KerasLayer. The details of the error and the stack trace are given below. The complete code is available in a gist. Any help or suggestions would be appreciated. Thanks.
from tensorflow import keras
def create_keras_model():
    """Build the binary text classifier trained by each federated client.

    The first layer wraps a TF Hub text encoder. It takes raw scalar strings
    (``dtype=tf.string``, ``input_shape=[]``) and is fine-tuned during
    training (``trainable=True``). Two ReLU dense layers follow. A single
    sigmoid unit produces a probability in [0, 1].

    Only the *features* are strings. Labels fed to this model must be
    numeric (e.g. ``tf.int32``). A ``tf.string`` label makes the training
    step try to cast it to float, which fails with
    ``UnimplementedError: Cast string to float is not supported``.

    Returns:
        A new, uncompiled ``tf.keras`` ``Sequential`` model. A fresh model is
        built on every call, which is what ``model_fn`` relies on.
    """
    # NOTE(review): the module handle was lost from the original paste
    # (``hub.load("")``); TODO fill in the TF Hub URL or SavedModel path.
    encoder = hub.load("")
    return tf.keras.models.Sequential([
        hub.KerasLayer(
            encoder,
            input_shape=[],
            dtype=tf.string,
            trainable=True,
        ),
        keras.layers.Dense(32, activation='relu'),
        keras.layers.Dense(16, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
def model_fn():
    """Build a new Keras model and convert it with tff.learning.from_keras_model.

    NOTE(review): the arguments to ``tff.learning.from_keras_model`` were
    lost from this paste. They are typically the Keras model plus
    ``input_spec``, ``loss`` and ``metrics``. TODO restore them; as shown,
    the call below is unterminated.
    """
    # We _must_ create a new model here, and _not_ capture it from an external
    # scope. TFF will call this within different graph contexts.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
# Building the Federated Averaging Process.
# `model_fn` itself (not a model instance) is the required first argument,
# so TFF can construct a fresh model inside each of its graph contexts.
# The optimizer arguments are lambdas so that a new optimizer is built each
# time TFF asks for one.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

# Initialize the server state, then run a single round of federated training.
# NOTE(review): `federated_train_data` is defined outside this snippet;
# presumably a list of per-client datasets. Per the resolution of this issue,
# their labels must be numeric, not tf.string.
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))
UnimplementedError Traceback (most recent call last)
<ipython-input-80-39d62fa827ea> in <module>()
----> 1 state, metrics = iterative_process.next(state, federated_train_data)
2 print('round 1, metrics={}'.format(metrics))
119 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in
quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
UnimplementedError: Cast string to float is not supported
[[{{node StatefulPartitionedCall_1/StatefulPartitionedCall/Cast_1}}]]
[[import/StatefulPartitionedCall_3/ReduceDataset]] [Op:__inference_wrapped_function_65986]
Function call stack:
wrapped_function -> wrapped_function -> wrapped_function
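The issue has now been resolved. The error was thrown because 'label' was being passed as tf.string instead of tf.int32. Explicitly converting the label to a numeric type resolved this. Note that tf.cast itself cannot convert strings to numbers, which is exactly what raises "Cast string to float is not supported"; tf.strings.to_number performs that conversion.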