tensorflow, tensorflow-federated

How can I extract each client's loss so that I can average only the top 20% of clients by loss?


I have the following code:

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In this code, I want to sort the clients by loss in descending order and average only the top 20% of them.

# The server selects the top 20% of clients
selected_clients_weights = client_select(client_weights)

How can I extract the loss of each client so that I can sort the clients by it?


Solution

  • A good starting point would be to look at the directory tensorflow_federated/python/examples/simple_fedavg/ and see how Federated Averaging is implemented.

    Extending this to average only the top 20% of clients by loss requires two things:

    1. Add an additional output from the client_update function, in this case a loss value.
    2. Replace the tff.federated_mean aggregation with a call to tff.federated_collect, which returns a sequence of the client outputs. Then write a new tff.tf_computation-decorated function that sorts this sequence (possibly by weight) and averages the selected entries, and apply it to the result of tff.federated_collect with tff.federated_map.
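
The selection-and-averaging logic from the two steps above can be sketched in plain Python, without TFF, to show what the server-side function would compute. This is only an illustration under stated assumptions: ClientOutput, select_top_by_loss and average_deltas are hypothetical names, model updates are represented as flat lists of floats, and the mean is unweighted. In TFF the same logic would live inside the tff.tf_computation applied to the collected sequence, and tff.federated_collect may not be available in every TFF release.

```python
from typing import List, NamedTuple


class ClientOutput(NamedTuple):
    """Hypothetical per-client result; `loss` is the extra output from step 1."""
    weights_delta: List[float]  # flattened model update (plain floats for illustration)
    loss: float                 # training loss reported by the client


def select_top_by_loss(outputs: List[ClientOutput],
                       top_percent: int = 20) -> List[ClientOutput]:
    """Returns the top `top_percent` percent of clients by loss, highest first.

    Always keeps at least one client (an assumption made for this sketch).
    """
    if not outputs:
        raise ValueError("no client outputs to select from")
    k = max(1, (len(outputs) * top_percent) // 100)
    ranked = sorted(outputs, key=lambda o: o.loss, reverse=True)
    return ranked[:k]


def average_deltas(outputs: List[ClientOutput]) -> List[float]:
    """Unweighted element-wise mean of the given clients' updates."""
    n = len(outputs)
    columns = zip(*(o.weights_delta for o in outputs))
    return [sum(values) / n for values in columns]


# Ten simulated clients; client i reports loss i and update [i, 2 * i].
outputs = [ClientOutput([float(i), 2.0 * i], loss=float(i)) for i in range(1, 11)]
selected = select_top_by_loss(outputs)   # the two clients with the highest loss
mean_delta = average_deltas(selected)    # average over those two clients only
```

With ten clients, the two with the highest loss are kept and only their updates are averaged. If clients should be weighted (for example by number of examples), the mean in average_deltas would need to become a weighted mean.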