tensorflow-federated

Is there a way for TFF clients to have internal states?


The code in the TFF tutorials and in the research projects I see generally only keeps track of server state. I’d like there to be internal client states (for instance, additional client-internal neural networks that are completely decentralized and don’t update in a federated manner) that would influence the federated client computations.

However, the client computations I have seen are only functions of the server state and the data. Is it possible to accomplish the above?


Solution

  • Yup, this is easy to express in TFF, and it will execute just fine in the default execution stacks.

    As you've noticed, the TFF repository generally has examples of cross-device federated learning (Kairouz et al., 2019). Generally we talk about the state as having tff.SERVER placement, and the function signature for one "round" of federated learning has the structure (for details on TFF's type shorthand, see the Federated data section of the tutorials):

    (<State@SERVER, {Dataset}@CLIENTS> -> State@SERVER)
    

    We can represent stateful clients by simply extending the signature:

    (<State@SERVER, {State}@CLIENTS, {Dataset}@CLIENTS> -> <State@SERVER, {State}@CLIENTS>)
    

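    In plain functional terms, a stateful round is just a function from (server state, per-client states, per-client data) to (new server state, new per-client states). A toy, non-TFF sketch, where simple arithmetic stands in for real local training:

    ```python
    def one_round(server_state, client_states, client_datasets):
        # Each client trains using its own persistent state; the state is
        # updated locally and never aggregated at the server.
        updates, new_client_states = [], []
        for state, data in zip(client_states, client_datasets):
            update = sum(data) + state           # stand-in for "local training"
            updates.append(update)
            new_client_states.append(state + 1)  # state evolves on the client only
        # Only the model updates are averaged into the server state.
        new_server_state = server_state + sum(updates) / len(updates)
        return new_server_state, new_client_states
    ```

    The point of the toy is the shape of the signature: client states go in, updated client states come back out, and only the updates cross to the server.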
    Implementing a version of Federated Averaging (McMahan et al., 2016) that includes a client state object might look something like:

    @tff.tf_computation(
      model_type,
      client_state_type,  # additional state parameter
      client_data_type)
    def client_training_fn(model, state, dataset):
      # ... do some local training here, producing a model update and an
      # updated client state (local_training is a placeholder) ...
      model_update, new_state = local_training(model, state, dataset)
      return model_update, new_state  # return a tuple including the updated state
    
    @tff.federated_computation(
      tff.FederatedType(server_state_type, tff.SERVER),
      tff.FederatedType(client_state_type, tff.CLIENTS),  # new parameter for state
      tff.FederatedType(client_data_type, tff.CLIENTS))
    def run_fed_avg(server_state, client_states, client_datasets):
      client_initial_models = tff.federated_broadcast(server_state.model)
      client_updates, new_client_states = tff.federated_map(client_training_fn,
        # Pass the client states as an argument.
        (client_initial_models, client_states, client_datasets))
      average_update = tff.federated_mean(client_updates)
      new_server_state = tff.federated_map(server_update_fn, (server_state, average_update))
      # Make sure to return the client states so they can be used in later rounds.
      return new_server_state, new_client_states
    

    Invoking run_fed_avg requires passing a Python list of tensors/structures, one per client participating in the round, and the result of the invocation will be the new server state and a list of updated client states.
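
    On the Python side this means the driver loop simply threads the client states from one round into the next. A minimal sketch, where round_fn stands in for the compiled run_fed_avg computation:

    ```python
    def run_rounds(round_fn, server_state, client_states, client_datasets, num_rounds):
        # Feed each round's output states back in as the next round's inputs.
        for _ in range(num_rounds):
            server_state, client_states = round_fn(
                server_state, client_states, client_datasets)
        return server_state, client_states
    ```

    Nothing in TFF persists the client states for you; they live wherever the driver script keeps them between invocations.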