tensorflow-federated

How to apply custom encoders to multiple clients at once? How to use custom encoders in run_one_round?


So my goal is basically to implement global top-k subsampling. Gradient sparsification is quite simple, and I have already done it building on the stateful clients example, but now I would like to use encoders as you recommended here at page 28. Additionally, I would like to average only the non-zero gradients: say we have 10 clients, but only 4 have non-zero gradients at a given position in a communication round; then I would like to divide the sum of these gradients by 4, not 10. I am hoping to achieve this by summing the gradients in the numerator and the masks (1s and 0s) in the denominator. Also, moving forward, I will add randomness to the gradient selection, so it is imperative that I create those masks concurrently with the gradient selection. The code I have right now is:

```python
import tensorflow as tf

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_adaptive_encoding_stage
class GradientSparsificationEncodingStage(te.core.AdaptiveEncodingStageInterface):
  """Global top-k gradient sparsification with error compensation.

  All input tensors are flattened into one vector, the error accumulated in
  earlier rounds is added back, and only the 10% of entries with the largest
  magnitude are kept. Everything that is dropped becomes the new
  error-compensation state. Only the kept (signed) values and their indices
  are encoded; decode scatters them back into dense tensors and also returns
  a 0/1 mask of the non-zero positions.
  """

  ENCODED_VALUES_KEY = 'stateful_topk_values'
  INDICES_KEY = 'indices'
  SHAPES_KEY = 'shapes'
  ERROR_COMPENSATION_KEY = 'error_compensation'

  def encode(self, x, encode_params):
    # Flatten every input tensor and concatenate them into one rank-1 vector.
    shapes_list = [tf.shape(y) for y in x]
    flattened = tf.nest.map_structure(lambda y: tf.reshape(y, [-1]), x)
    gradients = tf.concat(flattened, axis=0)
    error_compensation = encode_params[self.ERROR_COMPENSATION_KEY]

    # Error feedback: add back what was not transmitted in earlier rounds.
    gradients_and_error_compensation = tf.math.add(gradients, error_compensation)

    # k = 10% of the total number of entries.
    percentage = tf.constant(0.1, dtype=tf.float32)
    k_float = tf.multiply(
        percentage, tf.cast(tf.size(gradients_and_error_compensation), tf.float32))
    k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

    # Select positions by magnitude, but transmit the signed values there.
    _, indices = tf.math.top_k(
        tf.math.abs(gradients_and_error_compensation), k=k_int, sorted=False)
    values = tf.gather(gradients_and_error_compensation, indices)
    indices = tf.expand_dims(indices, 1)
    sparse_gradients_and_error_compensation = tf.scatter_nd(
        indices, values, tf.shape(gradients_and_error_compensation))

    # Whatever was not transmitted becomes the new error-compensation state.
    new_error_compensation = tf.math.subtract(
        gradients_and_error_compensation, sparse_gradients_and_error_compensation)
    state_update_tensors = {self.ERROR_COMPENSATION_KEY: new_error_compensation}

    encoded_x = {self.ENCODED_VALUES_KEY: values,
                 self.INDICES_KEY: indices,
                 self.SHAPES_KEY: shapes_list}

    return encoded_x, state_update_tensors

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    del num_summands, decode_params, shape  # Unused.
    shapes = encoded_tensors[self.SHAPES_KEY]
    sizes_list = [tf.math.reduce_prod(s) for s in shapes]
    flat_shape = tf.math.reduce_sum(sizes_list)
    # Scatter the transmitted values back into a dense rank-1 vector.
    scatter_tensor = tf.scatter_nd(
        indices=encoded_tensors[self.INDICES_KEY],
        updates=encoded_tensors[self.ENCODED_VALUES_KEY],
        shape=[flat_shape])
    # 0/1 mask of the non-zero positions (values can be negative, so test
    # for "not equal to zero" rather than "greater than zero").
    nonzero_locations = tf.cast(tf.math.not_equal(scatter_tensor, 0.0), tf.float32)
    # Split both vectors back into the original tensor shapes.
    reshaped_tensor = [tf.reshape(flat_tensor, shape=s) for flat_tensor, s in
                       zip(tf.split(scatter_tensor, sizes_list), shapes)]
    reshaped_nonzero = [tf.reshape(flat_tensor, shape=s) for flat_tensor, s in
                        zip(tf.split(nonzero_locations, sizes_list), shapes)]
    # Returns two structures: the dense values and the matching 0/1 mask.
    return reshaped_tensor, reshaped_nonzero

  def initial_state(self):
    # Scalar zero; broadcasts against the flattened gradients in round one.
    return {self.ERROR_COMPENSATION_KEY: tf.constant(0, dtype=tf.float32)}

  def update_state(self, state, state_update_tensors):
    # The new error-compensation vector replaces the previous one.
    return {self.ERROR_COMPENSATION_KEY: state_update_tensors[self.ERROR_COMPENSATION_KEY]}

  def get_params(self, state):
    encode_params = {self.ERROR_COMPENSATION_KEY: state[self.ERROR_COMPENSATION_KEY]}
    decode_params = {}
    return encode_params, decode_params

  @property
  def name(self):
    return 'gradient_sparsification_encoding_stage'

  @property
  def compressible_tensors_keys(self):
    # Keys of encoded tensors that later stages may compress further;
    # the integer indices must stay exact, so only the values are listed.
    return [self.ENCODED_VALUES_KEY]

  @property
  def commutes_with_sum(self):
    return False

  @property
  def decode_needs_input_shape(self):
    return False

  @property
  def state_update_aggregation_modes(self):
    return {}
```
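For checking the TensorFlow stage by hand, it can help to have a plain NumPy reference of what one client's encode step is meant to compute. The sketch below is an illustration under assumptions, not TFF or tensor_encoding code: topk_with_error_feedback is a hypothetical helper, and np.argpartition stands in for tf.math.top_k. It builds the 0/1 mask from the selected indices in the same step as the selection, which is one way to keep the mask consistent once random selection is added.

```python
import numpy as np


def topk_with_error_feedback(grads, error, fraction=0.1):
    """Hypothetical NumPy stand-in for one client's top-k encode step.

    grads: list of arrays of any shapes (e.g. all trainable weights).
    error: flat array holding the error accumulated in earlier rounds.
    Returns the kept signed values, their flat indices, a 0/1 mask of the
    kept positions, and the new error-compensation vector.
    """
    # Flatten and concatenate everything into one vector (global top-k),
    # then add back the previously dropped error.
    flat = np.concatenate([g.reshape(-1) for g in grads]) + error
    k = int(round(fraction * flat.size))
    if k > 0:
        # Indices of the k largest magnitudes (order is unspecified).
        indices = np.argpartition(np.abs(flat), -k)[-k:]
    else:
        indices = np.array([], dtype=int)
    values = flat[indices]  # keep the sign of the selected entries
    mask = np.zeros_like(flat)
    mask[indices] = 1.0  # mask built in the same step as the selection
    sparse = np.zeros_like(flat)
    sparse[indices] = values
    new_error = flat - sparse  # everything not transmitted is carried over
    return values, indices, mask, new_error


# Usage: two differently shaped "weights", 10 entries in total, keep 20%.
g1 = np.array([[0.5, -3.0], [0.1, 0.2]])
g2 = np.array([1.0, -0.05, 2.0, 0.0, 0.3, -0.4])
values, indices, mask, err = topk_with_error_feedback([g1, g2], np.zeros(10), 0.2)
print(sorted(indices.tolist()))  # [1, 6]
```

Note that the sign is preserved: selecting by np.abs but returning the absolute values would silently turn -3.0 into 3.0.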

I have run some simple tests manually, following the steps you outlined here at page 45. It works, but I have some questions and problems:

  1. When I pass a list of tensors of the same shape (e.g. two 2x25 tensors) as the input x of encode, it works without any issues, but when I use a list of tensors of different shapes (2x20 and 6x10), it raises an error:

```
InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [2,20] != values[1].shape = [6,10] [Op:Pack] name: packed
```

How can I resolve this issue? As I said, I want to use global top-k, so it is essential that I encode all of the trainable model weights at once. Take the CNN model used here: all of its tensors have different shapes.

  2. How can I do the averaging I described at the beginning? For example, here you have done:

```python
mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn),  # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn),  # denominator
)
```

Is there a way to repeat this with one output of decode going to the numerator and the other going to the denominator? How can I handle dividing 0 by 0? TensorFlow has the divide_no_nan function; can I use it somehow, or do I need to add an epsilon to each?

  3. How is partitioning handled when I use encoders? Does each client get its own encoder holding its own state? As you discussed here at page 6, client states are used in cross-silo settings, but what happens if the client ordering changes?

  4. Here you recommended building on the stateful clients example. Can you explain this a bit further? Specifically, in run_one_round, where exactly do the encoders go, and how are they used and combined with the client update and aggregation?

  5. I have some additional information, such as the sparsity level, that I want to pass to encode. What is the suggested method for doing that?
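For the averaging in question 2, the arithmetic itself is small. Below is a minimal NumPy sketch, assuming each client's decoded update and its 0/1 mask are already available as dense arrays stacked along the first axis; masked_mean is a hypothetical helper, not a TFF API. In a federated setting the numerator and denominator would presumably each be aggregated as a sum and divided once on the server, which is where tf.math.divide_no_nan would fit, since it returns 0 wherever the denominator is 0.

```python
import numpy as np


def masked_mean(client_values, client_masks):
    """Hypothetical helper: average each position only over contributors.

    client_values: array [num_clients, n] of dense (mostly zero) updates.
    client_masks: array [num_clients, n] of 0/1 flags for sent positions.
    Positions that no client sent come out as 0 instead of NaN (0 / 0),
    mirroring the behavior of tf.math.divide_no_nan.
    """
    numerator = np.sum(client_values, axis=0)    # sum of gradients
    denominator = np.sum(client_masks, axis=0)   # number of contributors
    # Divide only where at least one client contributed; all other
    # positions keep the 0.0 from `out`.
    return np.divide(numerator, denominator,
                     out=np.zeros_like(numerator, dtype=float),
                     where=denominator != 0)


# 10 clients; only the first 4 send a value (2.0) at position 0,
# nobody sends anything at position 1.
values = np.zeros((10, 2))
masks = np.zeros((10, 2))
values[:4, 0] = 2.0
masks[:4, 0] = 1.0
print(masked_mean(values, masks))  # [2. 0.]
```

A plain mean over all 10 clients would give 0.8 at position 0. Adding a small epsilon to the denominator also avoids NaN, but it slightly shrinks every average (8 / (4 + eps) instead of exactly 2), whereas the masked division leaves positions with at least one contributor exact.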


Solution

  • Here are some answers; I hope they help:

    1. If you want to treat the whole aggregated structure as a single tensor, use concat_factory as the outermost aggregator. It concatenates the entire structure into a rank-1 tensor at the clients, and then unpacks it back into the original structure at the end. Example use: tff.aggregators.concat_factory(tff.aggregators.MeanFactory(...))

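    Note that encoding stage objects are meant to work with a single tensor, so what you describe with same-shaped tensors probably works only by accident (a list of same-shaped tensors can be packed into one stacked tensor, while differently shaped tensors cannot, hence the Op:Pack error).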

    2. There are two options.

      a. Modify the client training code so that the weights passed to the weighted aggregator are already what you want them to be (a zero/one mask). In the stateful clients example you linked, that would be here. You will then get what you need by default (by summing the numerator).

      b. Modify UnweightedMeanFactory to do exactly the variant of averaging you describe, and use that. A starting point would be modifying this.

    3. (and 4.) I think that is what you would need to implement. In the same way that the existing client states are initialized in the example here, you would need to extend them to contain the aggregator states, and make sure those are sampled together with the clients, as done here. Then, to integrate the aggregators into the example, you would need to replace this hard-coded tff.federated_mean. An example of such an integration is in the implementation of tff.learning.build_federated_averaging_process, primarily here.

    5. I am not sure what the question is. Perhaps get the previous points working first (that seems like a prerequisite to me), and then clarify and ask in a new post?
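The first answer relies on tff.aggregators.concat_factory. To illustrate why flattening the structure avoids the packing error from question 1, here is a NumPy sketch of the same flatten-then-restore idea. concat_structure and unpack_structure are hypothetical helpers; this is not TFF's implementation, which operates on federated values and TFF types. Once the structure arrives as one rank-1 tensor, the encoding stage only ever sees a single tensor, so the stage's own shape bookkeeping would likely become unnecessary.

```python
import numpy as np


def concat_structure(tensors):
    """Hypothetical helper: flatten arrays into one rank-1 vector + shapes."""
    shapes = [t.shape for t in tensors]
    flat = np.concatenate([t.reshape(-1) for t in tensors])
    return flat, shapes


def unpack_structure(flat, shapes):
    """Hypothetical helper: split a rank-1 vector back into the shapes."""
    sizes = [int(np.prod(s)) for s in shapes]
    # Split points are the running totals of the sizes, minus the last one.
    pieces = np.split(flat, np.cumsum(sizes)[:-1])
    return [p.reshape(s) for p, s in zip(pieces, shapes)]


# The two differently shaped tensors from question 1.
w1 = np.arange(40.0).reshape(2, 20)
w2 = np.arange(60.0).reshape(6, 10)
flat, shapes = concat_structure([w1, w2])
print(flat.shape)  # (100,)
restored = unpack_structure(flat, shapes)
```

Any per-element operation (such as global top-k) can then be applied to flat before unpacking, since all weights share one index space.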