
How to gather all client weights at server in TFF?


I am trying to implement a custom aggregation in TFF by adapting the code from this tutorial. I would like to rewrite next_fn so that all the client weights are placed at the server for further computation. Since federated_collect was removed from tff-nightly, I am trying to do this with federated_aggregate.

This is what I have so far:

def accumulate(x, y):
    x.append(y)
    return x


def merge(x, y):
    x.extend(y)
    return x


@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(
        server_state.trainable_weights)
    client_deltas = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))

    z = []
    agg_result = tff.federated_aggregate(client_deltas, z,
                                         accumulate=tff.tf_computation(accumulate),
                                         merge=tff.tf_computation(merge),
                                         report=tff.tf_computation(lambda x: x))

    new_weights = do_smth_with_result(agg_result)
    server_state = tff.federated_map(
        server_update_fn, (server_state, new_weights))
    return server_state

However, this results in the following exception:

  File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 351, in <module>
    def next_fn(server_state, federated_dataset):
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 494, in __call__
    wrapped_func = self._strategy(
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 222, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 358, in next_fn
    agg_result = tff.federated_aggregate(client_deltas, z,
  File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/federated_context/intrinsics.py", line 140, in federated_aggregate
    raise TypeError(
TypeError: Expected parameter `accumulate` to be of type (<<<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>), but received (<<>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>) instead.

Solution

  • Try using tff.aggregators.federated_sample with max_num_samples set to the number of clients you have.

    That should be a simple drop-in replacement for how you would previously use tff.federated_collect.


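    In your accumulate, the issue is that you are changing the number of tensors the accumulator contains. The traceback shows this: accumulate takes an empty structure (<>) as its accumulator but returns a structure with one element, whereas federated_aggregate expects the accumulator type going in and coming out to match. That is why you get an error as soon as more than one accumuland is accumulated. If you still want to go this way, then for a rank-1 accumuland with k elements you could probably do something like the following instead: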

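    # k is the static length of each client's rank-1 value (a Python int).
    @tff.tf_computation(tff.types.TensorType(tf.float32, [None, k]),
                        tff.types.TensorType(tf.float32, [k]))
    def accumulate(accumulator, accumuland):
      # Append the new value as a row; the accumulator type stays float32[?, k].
      return tf.concat([accumulator, tf.expand_dims(accumuland, axis=0)], axis=0)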
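
To see why a fixed-rank accumulator avoids the type mismatch, here is a minimal sketch of the zero/accumulate/merge/report contract using NumPy as a stand-in for TFF. It is not TFF code: the names zero, merge, report, simulate_aggregate and the constant K are illustrative assumptions, and the grouping only mimics the idea that partial accumulators may be merged. The point it shows is that the accumulator keeps the same "type" (float32 with shape [?, K]) at every step, starting from an empty array with zero rows rather than an empty Python list.

```python
from functools import reduce

import numpy as np

K = 4  # hypothetical length of each client's rank-1 value


def zero():
    # Empty accumulator with zero rows; its "type" is already float32[?, K].
    return np.zeros((0, K), dtype=np.float32)


def accumulate(accumulator, accumuland):
    # Append one client value as a new row; the result is still float32[?, K].
    return np.concatenate([accumulator, accumuland[np.newaxis, :]], axis=0)


def merge(a, b):
    # Combine two partial accumulators; again float32[?, K].
    return np.concatenate([a, b], axis=0)


def report(accumulator):
    # Identity: hand the stacked values to the server unchanged.
    return accumulator


def simulate_aggregate(client_values, num_groups=2):
    # Mimic partial aggregation: accumulate within groups, then merge the groups.
    groups = [client_values[i::num_groups] for i in range(num_groups)]
    partials = [reduce(accumulate, group, zero()) for group in groups]
    return report(reduce(merge, partials, zero()))


# Five made-up clients, each contributing a constant vector of length K.
clients = [np.full(K, float(i), dtype=np.float32) for i in range(5)]
collected = simulate_aggregate(clients)
print(collected.shape)
```

Note that in this sketch the row order of the result follows the grouping, not the client order, so any server-side computation on the collected values should probably not rely on a particular ordering.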