I'm trying to use tensorflow-federated to select different subsets of weights at the server and send them to the clients. The clients would then train and send back the trained weights. The server aggregates the results and starts a new communication round.
The main problem is that I cannot access the numpy version of the weights and therefore I don't know how to access a subset of them for each layer. I tried using tf.gather_nd and tf.tensor_scatter_nd_update to perform selection and update, but they only work for tensors, and not lists of tensors (as the server_state is in tensorflow-federated).
Does anyone have any hint to solve this problem? Is it even possible to send different weights to each client?
If I follow correctly, the high-level computation being described could be sketched in TFF roughly as:
@tff.federated_computation(...)
def run_one_round(server_state, client_datasets):
    # Select the subset of weights on the server and send it to the clients.
    weights_subset = tff.federated_map(subset_fn, server_state)
    clients_weights_subset = tff.federated_broadcast(weights_subset)
    # Train on each client, then aggregate the results back on the server.
    client_models = tff.federated_map(client_training_fn,
                                      (clients_weights_subset, client_datasets))
    aggregated_update = tff.federated_aggregate(client_models, ...)
    # Apply the aggregated update to the current server state.
    new_server_state = tff.federated_map(apply_aggregated_update_fn,
                                         (server_state, aggregated_update))
    return new_server_state
If this is true, it seems like the majority of the work needs to happen in subset_fn, which takes the server state and returns a subset of the global model weights. Generally a model is a structure (list or dict, possibly nested) of tf.Tensor, which as you observed cannot be passed directly as an argument to tf.gather_nd or tf.tensor_scatter_nd_update. However, these ops can be applied pointwise to the structure of tensors using tf.nest.map_structure. For example, selecting the value at [0, 0] from a nested structure of three tensors:
import pprint

import tensorflow as tf

struct_of_tensors = {
    'trainable': [tf.constant([[2.0, 4.0, 6.0]]), tf.constant([[5.0]])],
    'non_trainable': [tf.constant([[1.0]])],
}

# Apply tf.gather_nd to every tensor in the nested structure.
pprint.pprint(tf.nest.map_structure(
    lambda tensor: tf.gather_nd(params=tensor, indices=[[0, 0]]),
    struct_of_tensors))
>>> {'non_trainable': [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>],
     'trainable': [<tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.], dtype=float32)>,
                   <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]}
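The same pattern should extend to choosing different indices for each layer and to writing trained values back with tf.tensor_scatter_nd_update, since tf.nest.map_structure accepts several structures as long as they have the same nesting. The following is only a sketch in plain TensorFlow, outside of TFF. The indices structure, the helper names select_subset and write_back, and the "double every value" stand-in for client training are all assumptions made up for illustration, not part of the question or of any TFF API.

```python
import tensorflow as tf

# Hypothetical model weights: a nested structure of 2-D tensors.
weights = {
    'trainable': [tf.constant([[2.0, 4.0, 6.0]]), tf.constant([[5.0]])],
    'non_trainable': [tf.constant([[1.0]])],
}

# Hypothetical per-tensor indices, in a structure that mirrors `weights`.
# Each row of an index tensor addresses one scalar entry of the weight tensor.
indices = {
    'trainable': [tf.constant([[0, 0], [0, 2]]), tf.constant([[0, 0]])],
    'non_trainable': [tf.constant([[0, 0]])],
}


def select_subset(weights, indices):
    """Gathers the entries named by `indices` from each weight tensor."""
    return tf.nest.map_structure(
        lambda w, i: tf.gather_nd(params=w, indices=i), weights, indices)


def write_back(weights, indices, updates):
    """Scatters `updates` into each weight tensor at the given `indices`."""
    return tf.nest.map_structure(
        lambda w, i, u: tf.tensor_scatter_nd_update(w, i, u),
        weights, indices, updates)


subset = select_subset(weights, indices)
# Stand-in for client training (an assumption): double each selected value.
trained_subset = tf.nest.map_structure(lambda t: t * 2.0, subset)
new_weights = write_back(weights, indices, trained_subset)
```

Here subset['trainable'][0] holds the two selected entries of the first tensor, and new_weights has the same shapes as weights with only the selected entries replaced. Inside TFF, functions like these would presumably need to be wrapped as TensorFlow computations before being passed to tff.federated_map.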