I have the following code:
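    import tensorflow as tf
    import tensorflow_federated as tff

    def model_fn():
        # create_keras_model() and federated_train_data are defined elsewhere.
        keras_model = create_keras_model()
        return tff.learning.from_keras_model(
            keras_model,
            input_spec=federated_train_data[0].element_spec,
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])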
With this setup, I want the server to sort the clients by their loss in descending order and average only the top 20% of them:
    # Server: select the top 20% of clients by loss
    selected_clients_weights = client_select(client_weights)
How can I extract each client's loss so that the clients can be sorted?
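The client_select step can be made concrete with a small framework-free sketch. It assumes each client's loss is already available next to its weights (obtaining that loss is what the question is about), so this hypothetical signature takes the losses as a second argument. client_select, average, and the fraction parameter are illustrative names, not part of TFF, and weights are represented as plain Python lists of floats.

```python
import math
from typing import List, Sequence


def client_select(client_weights: Sequence[List[float]],
                  client_losses: Sequence[float],
                  fraction: float = 0.2) -> List[List[float]]:
    """Hypothetical helper: return the weights of the `fraction` of clients
    with the highest loss, sorted by loss in descending order.

    At least one client is always kept.
    """
    if len(client_weights) != len(client_losses):
        raise ValueError("need exactly one loss per client")
    num_selected = max(1, math.ceil(fraction * len(client_weights)))
    # Client indices ordered from highest to lowest loss.
    order = sorted(range(len(client_losses)),
                   key=lambda i: client_losses[i], reverse=True)
    return [client_weights[i] for i in order[:num_selected]]


def average(weights: Sequence[List[float]]) -> List[float]:
    """Element-wise unweighted mean of equally shaped weight vectors."""
    return [sum(values) / len(weights) for values in zip(*weights)]
```

With ten clients whose losses are 0 through 9, client_select keeps the two clients with losses 9 and 8, and average returns their element-wise mean.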
A good starting point would be to look at the directory tensorflow_federated/python/examples/simple_fedavg/ and see how Federated Averaging is implemented there.
Extending this to average only the top 20% of clients by loss requires two changes:

1. Add a second output to the client_update function, in this case a loss value.
2. Replace the tff.federated_mean aggregation with a call to tff.federated_collect. This returns a sequence, which could then be sorted (possibly by weight) and averaged inside a new tff.tf_computation-decorated method that is applied to the result of tff.federated_collect with tff.federated_map.
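The data flow of these two changes can be sketched without TFF. This is a toy stand-in, not the simple_fedavg code: client_update here trains a one-parameter linear model y = weight * x with plain gradient descent and returns its loss as an extra output; a Python list plays the role of the sequence tff.federated_collect would return; and server_aggregate does the sorting, top-20% selection, and averaging that the new tff.tf_computation would perform. Weighting the kept clients by their example counts is one possible choice, assumed here. ClientOutput, client_update, server_aggregate, and their parameters are hypothetical names chosen for this illustration.

```python
import math
from typing import NamedTuple, Sequence, Tuple


class ClientOutput(NamedTuple):
    """What each client reports: updated weight, its loss, and data size."""
    weight: float
    loss: float
    num_examples: int


def client_update(weight: float, data: Sequence[Tuple[float, float]],
                  lr: float = 0.1, steps: int = 5) -> ClientOutput:
    """Toy local training of y = weight * x with mean squared error.

    The final loss is returned as an extra output (change 1).
    """
    n = len(data)
    for _ in range(steps):
        grad = sum(2 * (weight * x - y) * x for x, y in data) / n
        weight -= lr * grad
    loss = sum((weight * x - y) ** 2 for x, y in data) / n
    return ClientOutput(weight, loss, n)


def server_aggregate(outputs: Sequence[ClientOutput],
                     fraction: float = 0.2) -> float:
    """Sort the collected outputs by loss (descending), keep the top
    `fraction`, and return their example-weighted average weight (change 2).
    """
    k = max(1, math.ceil(fraction * len(outputs)))
    top = sorted(outputs, key=lambda o: o.loss, reverse=True)[:k]
    total = sum(o.num_examples for o in top)
    return sum(o.weight * o.num_examples for o in top) / total


# Usage: ten simulated clients; client c holds data following y = c * x.
client_datasets = [[(1.0, float(c)), (2.0, 2.0 * c)] for c in range(10)]
# A plain list stands in for the sequence tff.federated_collect would return.
outputs = [client_update(0.0, data) for data in client_datasets]
new_weight = server_aggregate(outputs)
```

Because every client starts from weight 0 and runs the same number of steps, clients with larger c finish with larger losses, so the server keeps clients 9 and 8. In an actual TFF computation, the sorting and selection would likely need to be written with TensorFlow ops inside the tff.tf_computation rather than with Python's sorted.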