python tensorflow machine-learning tensorflow-federated federated-learning

Custom model aggregator in TensorFlow Federated


I am experimenting with TensorFlow Federated, simulating a training process with the FedAvg algorithm.

def model_fn():
  # Wrap a Keras model for use with TensorFlow Federated
  keras_model = get_uncompiled_model()

  # For the federated procedure, the model must be uncompiled
  return tff.learning.models.functional_model_from_keras(
        keras_model,
        loss_fn=tf.keras.losses.BinaryCrossentropy(),
        input_spec=(
              tf.TensorSpec(shape=[None, X_train.shape[1]], dtype=tf.float32),
              tf.TensorSpec(shape=[None], dtype=tf.int32)
        ),
        metrics_constructor=collections.OrderedDict(
              accuracy=tf.keras.metrics.BinaryAccuracy,
              precision=tf.keras.metrics.Precision,
              recall=tf.keras.metrics.Recall,
              false_positives=tf.keras.metrics.FalsePositives,
              false_negatives=tf.keras.metrics.FalseNegatives,
              true_positives=tf.keras.metrics.TruePositives,
              true_negatives=tf.keras.metrics.TrueNegatives
            )
  )

trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn(),
    client_optimizer_fn=client_optimizer,
    server_optimizer_fn=server_optimizer
)

I want to use custom weights to aggregate the clients' updates instead of using their number of samples. I know that tff.learning.algorithms.build_weighted_fed_avg() has a parameter called client_weighting, but the only value accepted is from the class tff.learning.ClientWeighting, which is an enum.

So, the only way to do that seems to be to write a custom WeightedAggregator. I've tried following this tutorial, which explains how to write an unweighted aggregator, but I have not been able to adapt it into a weighted one.

This is what I've tried to do:

@tff.tensorflow.computation
def custom_weighted_aggregate(values, weights):
    # Normalize client weights
    total_weight = tf.reduce_sum(weights)
    normalized_weights = weights / total_weight

    # Compute weighted sum of client updates
    weighted_sum = tf.nest.map_structure(
        lambda v: tf.reduce_sum(normalized_weights * v, axis=0),
        values
    )

    return weighted_sum

class CustomWeightedAggregator(tff.aggregators.WeightedAggregationFactory):
    def __init__(self):
        pass

    def create(self, value_type, weight_type):
        @tff.federated_computation
        def initialize():
            return tff.federated_value(0.0, tff.SERVER)

        @tff.federated_computation(
            initialize.type_signature.result,
            tff.FederatedType(value_type, tff.CLIENTS),
            tff.FederatedType(weight_type, tff.CLIENTS)
        )
        def next(state, value, weight):
            aggregate_value = tff.federated_map(custom_weighted_aggregate, (value, weight))
            return tff.templates.MeasuredProcessOutput(
                state, aggregate_value, tff.federated_value((), tff.SERVER)
            )

        return tff.templates.AggregationProcess(initialize, next)

    @property
    def is_weighted(self):
        return True

But I get the following error:

AggregationPlacementError: The "result" attribute of return type of next_fn must be placed at SERVER, but found {<float32[7],float32,float32[1],float32>}@CLIENTS.


Solution

  • To do cross-device reductions in TFF, you must use TFF's special intrinsic symbols. These essentially 'register' a reduction (such as the sum you intend with reduce_sum above) as special, so that it can be identified later as one the user meant to run cross-device.

    In TFF, pure TensorFlow logic always runs locally; it is never extracted to run cross-device. This means that the tff.tensorflow.computation above (custom_weighted_aggregate) really expresses a per-client computation rather than a cross-client reduction. Applying it with tff.federated_map therefore leaves the result placed at CLIENTS, which is what the error reports.

    If your values and weights are placed at clients, one way to express such a reduction is captured in this implementation. Alternatively, I believe that implementation should be directly usable through this symbol, whose create method should return an aggregation process to which you can pass custom weights.
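To make the per-client versus cross-client distinction concrete, here is a minimal plain-Python analogy. It does not use TFF at all: each list entry stands in for one client, and the function names (per_client_map, cross_client_weighted_mean) are hypothetical, chosen only for illustration. In TFF itself the second step would be expressed with an intrinsic such as tff.federated_mean, which accepts an optional weight argument.

```python
# Plain-Python analogy only; no TFF involved. Each list entry represents
# the data held by one client. All names here are illustrative assumptions.

def per_client_map(values, weights):
    """Mimics a federated map: runs once per client.

    The output still has one entry per client, i.e. it stays "at CLIENTS".
    """
    return [[w * x for x in v] for v, w in zip(values, weights)]


def cross_client_weighted_mean(values, weights):
    """Mimics a weighted-mean reduction across clients.

    All clients collapse into a single result, i.e. it lands "at SERVER".
    """
    total = sum(weights)
    dim = len(values[0])
    return [
        sum(w * v[i] for v, w in zip(values, weights)) / total
        for i in range(dim)
    ]


client_values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one update per client
client_weights = [1.0, 1.0, 2.0]                      # custom client weights

mapped = per_client_map(client_values, client_weights)
reduced = cross_client_weighted_mean(client_values, client_weights)

print(len(mapped))  # still three per-client results
print(reduced)      # one aggregated result
```

The mapped output keeps one entry per client, which mirrors why the aggregator in the question fails the SERVER placement check, while only the reduction produces the single value an aggregation process is expected to return.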