Tags: python, tensorflow, tensorflow-federated

Access and modify weights sent from the client to the server in TensorFlow Federated


I'm using TensorFlow Federated, but I'm having some problems while trying to execute some operations on the server after reading the client update.

This is the function:

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

I want to print client_outputs.weights_delta and do some operations on the weights the clients sent to the server before using tff.federated_mean, but I don't understand how to do so.

When I try to print it, I get this:

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

Is there any way to modify those elements?

I tried returning client_outputs.weights_delta.comp and doing the modification in the main loop (which I can do), and then I tried to invoke a new method for the rest of the server-update operations, but I get this error:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean', where calculate_federated_mean is the name of the new function I created.
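Presumably this is because an IterativeProcess only exposes initialize and next, so the extra computation would have to be built and kept as a separate object rather than called as a method of the process. A minimal, hypothetical sketch of that pattern with toy float32 types (not the actual simple_fedavg types):

import tensorflow as tf
import tensorflow_federated as tff

@tff.federated_computation()
def server_init():
    # Toy state: a single float placed at the server.
    return tff.federated_value(0.0, tff.SERVER)

@tff.federated_computation(tff.type_at_server(tf.float32),
                           tff.type_at_clients(tf.float32))
def next_fn(state, client_values):
    return state, tff.federated_mean(client_values)

@tff.federated_computation(tff.type_at_clients(tf.float32))
def calculate_federated_mean(client_values):
    return tff.federated_mean(client_values)

iterative_process = tff.templates.IterativeProcess(
    initialize_fn=server_init, next_fn=next_fn)

# iterative_process.calculate_federated_mean(...)  # AttributeError, as above
# calculate_federated_mean([1.0, 2.0, 3.0])        # works: call the separate computation directly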

This is the main loop:

    for round_num in range(FLAGS.total_rounds):
        print("--------------------------------------------------------")
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            server_state.model_weights.assign_weights_to(keras_model)
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

This is based on the simple_fedavg project from GitHub ([TensorFlow Federated simple_fedavg][1]).

EDIT 1:

So, thanks to @Jakub Konecny I made some progress, but I have found a new problem that I don't understand.

So, if I use this client_update:

@tf.function
def client_update(model, dataset, server_message, client_optimizer):
    """Performans client local training of `model` on `dataset`.
    Args:
      model: A `tff.learning.Model`.
      dataset: A 'tf.data.Dataset'.
      server_message: A `BroadcastMessage` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer`.
    Returns:
      A `ClientOutput`.
    """
    model_weights = model.weights
    initial_weights = server_message.model_weights
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)

    num_examples = tf.constant(0, dtype=tf.int32)
    loss_sum = tf.constant(0, dtype=tf.float32)
    # Explicit use `iter` for dataset is a trick that makes TFF more robust in
    # GPU simulation and slightly more performant in the unconventional usage
    # of large number of small datasets.
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, model_weights.trainable)
        client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
        batch_size = tf.shape(batch['x'])[0]
        num_examples += batch_size
        loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          model_weights.trainable,
                                          initial_weights.trainable)


    client_weight = tf.cast(num_examples, tf.float32)

    import sparse_ternary_compression
    sparsification_rate = 1
    testing_new = []
    # TODO: do not apply this to the biases
    for tensor in weights_delta:
        testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))

    return ClientOutput(weights_delta, client_weight, loss_sum / client_weight, testing_new)
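(Side note: for the extra testing_new element to come back out of client_update, the ClientOutput container from simple_fedavg_tf.py presumably needs a matching fourth field. A hedged sketch of the extended attrs class; only the first three field names come from simple_fedavg, and the test field is the addition referenced further below:)

import attr

@attr.s(eq=False, frozen=True)
class ClientOutput(object):
    """Structure for outputs returned from clients during a training round."""
    weights_delta = attr.ib()
    client_weight = attr.ib()
    model_output = attr.ib()
    test = attr.ib()  # holds the stc_compression outputs (testing_new)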

The client_update above uses these functions:

@tff.tf_computation
def stc_compression(original_tensor, sparsification_percentage):
    original_shape = tf.shape(original_tensor)
    tensor = tf.reshape(original_tensor, [-1])
    sparsification_percentage = tf.cast(sparsification_percentage, tf.float64)
    sparsification_rate = tf.size(tensor) / 100 * sparsification_percentage
    sparsification_rate = tf.cast(sparsification_rate, tf.int32)
    new_shape = tensor.get_shape().as_list()
    if sparsification_rate == 0:
        sparsification_rate = 1
    mask = tf.cast(tf.abs(tensor) >= tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    inv_mask = tf.cast(tf.abs(tensor) < tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    tensor_masked = tf.multiply(tensor, mask)
    sparsification_rate = tf.cast(sparsification_rate, tf.float32)
    average = tf.reduce_sum(tf.abs(tensor_masked)) / sparsification_rate
    compressed_tensor = tf.add(tf.multiply(average, mask) * tf.sign(tensor), tf.multiply(tensor_masked, inv_mask))
    negatives = tf.where(compressed_tensor < 0)
    positives = tf.where(compressed_tensor > 0)
    return negatives, positives, average, original_shape, new_shape

@tff.tf_computation
def stc_decompression(negatives, positives, average, original_shape, new_shape):
    decompressed_tensor = tf.zeros(new_shape, tf.float32)
    average_values_negative = tf.fill([tf.shape(negatives)[0], ], -average)
    average_values_positive = tf.fill([tf.shape(positives)[0], ], average)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, negatives, average_values_negative)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, positives, average_values_positive)
    decompressed_tensor = tf.reshape(decompressed_tensor, original_shape)
    return decompressed_tensor


@tff.tf_computation
def testing_new_list(compressed_tensors):
    testing = []
    for compressed in compressed_tensors:
        testing.append(
            stc_decompression(compressed[0], compressed[1],
                              compressed[2], compressed[3],
                              compressed[4]))

    return testing

These are called like this inside the run_one_round function:

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight

    import sparse_ternary_compression
    testing = tff.federated_map(sparse_ternary_compression.testing_new_list, client_outputs.test)

    # round_model_delta holds the weights that are used in server_update, so this is what needs to be changed
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, testing

but I get this exception:

Traceback (most recent call last):
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 214, in <module>
    app.run(main)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 171, in main
    iterative_process = simple_fedavg_tff.build_federated_averaging_process(
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 95, in build_federated_averaging_process
    def client_update_fn(tf_dataset, server_message):
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 478, in __call__
    wrapped_func = self._strategy(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 216, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 98, in client_update_fn
    return client_update(model, tf_dataset, server_message, client_optimizer)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 763, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:

        /mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tf.py:222 client_update  *
            testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:608 __call__  *
            return concrete_fn(packed_arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:525 __call__  *
            return context.invoke(self, arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py:54 invoke  *
            init_op, result = (
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py:1097 deserialize_and_call_tf_computation  *
            input_map = {
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3931 get_tensor_by_name  **
            return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3755 as_graph_element
            return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3795 _as_graph_element_locked
            raise KeyError("The name %s refers to a Tensor which does not "
    
        KeyError: "The name 'sub:0' refers to a Tensor which does not exist. The operation, 'sub', does not exist in the graph."
    
    
Process finished with exit code 1

EDIT 2:

Fixed the problem above by changing the decorator of the stc_compression and stc_decompression functions from tff.tf_computation to tf.function. Now it seems to work fine because, if I print the testing variable that I get from return server_state, round_loss_metric, testing inside run_one_round, I get the weights I wanted from the start.
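For reference, a minimal sanity check of the compression round trip, assuming stc_compression and stc_decompression are now plain @tf.functions as described above (the input tensor and the 50% rate are just illustrative):

import tensorflow as tf

t = tf.constant([[0.5, -2.0, 0.1],
                 [3.0, -0.2, 0.05]])

# Keep the top 50% of entries by magnitude (3 of the 6 values here).
negatives, positives, average, original_shape, new_shape = stc_compression(t, 50)
recovered = stc_decompression(negatives, positives, average, original_shape, new_shape)

print(recovered.shape)  # (2, 3): same shape as the input
# The kept entries are replaced by +/- average; everything else is zero.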


Solution

  • I think the reply I just wrote to your other question applies here, too.

    When you print client_outputs.weights_delta you get an abstract representation of the result of another computation, which is primarily an internal implementation detail of TFF.

    Write a tff.tf_computation-decorated method with TensorFlow code that does the modification you need, and then invoke it using the tff.federated_map operator at the place where you are trying to print the values.
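    For instance, a minimal sketch of that pattern, assuming model_delta_type is the TFF type of a single client's weights_delta structure (it must match your model's trainable variables) and scale_delta is just an illustrative transformation, not something from the simple_fedavg code:

    @tff.tf_computation(model_delta_type)
    def scale_delta(weights_delta):
        # Plain TensorFlow code: tf.print executes here, and the structure can
        # be modified element-wise before it is averaged.
        tf.print(weights_delta)
        return tf.nest.map_structure(lambda t: 0.5 * t, weights_delta)

    # Inside run_one_round, applied to the CLIENTS-placed values:
    modified_delta = tff.federated_map(scale_delta, client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(modified_delta, weight=weight_denom)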