In TensorFlow Federated (TFF), you can pass a broadcast_process and an aggregation_process to tff.learning.build_federated_averaging_process. Both can embed customized encoders, e.g. to apply custom compression.
To get to the point of my question: I am trying to implement an encoder that sparsifies model updates/model weights. I am trying to build it by implementing the EncodingStageInterface from tensorflow_model_optimization.python.core.internal.
However, I am struggling to implement a (local) state that accumulates the zeroed-out coordinates of model updates/model weights round by round. Note that this state should not be communicated; it only needs to be maintained locally (so the AdaptiveEncodingStageInterface does not seem to help). In general, the question is how to maintain a local state inside an encoder that is then passed to the fedavg process.
Below is the code of my encoder implementation, which works as expected when it is stateless (that is, without the state I would like to add), followed by the excerpt of my code that uses it. If I uncomment the commented-out lines in stateful_encoding_stage_topk.py, the code no longer works: I can't figure out how to manage the state (which is a Tensor) in TF non-eager mode.
stateful_encoding_stage_topk.py
import tensorflow as tf
import numpy as np
from tensorflow_model_optimization.python.core.internal import tensor_encoding as te


@te.core.tf_style_encoding_stage
class StatefulTopKEncodingStage(te.core.EncodingStageInterface):

    ENCODED_VALUES_KEY = 'stateful_topk_values'
    INDICES_KEY = 'indices'

    def __init__(self):
        super().__init__()
        # Here I would like to init my state
        #self.A = tf.zeros([800], dtype=tf.float32)

    @property
    def name(self):
        """See base class."""
        return 'stateful_topk'

    @property
    def compressible_tensors_keys(self):
        """See base class."""
        return [self.ENCODED_VALUES_KEY]

    @property
    def commutes_with_sum(self):
        """See base class."""
        return True

    @property
    def decode_needs_input_shape(self):
        """See base class."""
        return True

    def get_params(self):
        """See base class."""
        return {}, {}

    def encode(self, x, encode_params):
        """See base class."""
        del encode_params  # Unused.

        dW = tf.reshape(x, [-1])
        # Here I would like to retrieve the state
        A = tf.zeros([800], dtype=tf.float32)
        #A = self.residual

        dW_and_A = tf.math.add(A, dW)

        percentage = tf.constant(0.4, dtype=tf.float32)
        k_float = tf.multiply(percentage, tf.cast(tf.size(dW), tf.float32))
        k_int = tf.cast(tf.math.round(k_float), dtype=tf.int32)

        values, indices = tf.math.top_k(tf.math.abs(dW_and_A), k = k_int, sorted = False)
        indices = tf.expand_dims(indices, 1)
        sparse_dW = tf.scatter_nd(indices, values, tf.shape(dW_and_A))

        # Here I would like to update the state
        A_updated = tf.math.subtract(dW_and_A, sparse_dW)
        #self.A = A_updated

        encoded_x = {self.ENCODED_VALUES_KEY: values,
                     self.INDICES_KEY: indices}

        return encoded_x

    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class."""
        del decode_params, num_summands  # Unused.

        indices = encoded_tensors[self.INDICES_KEY]
        values = encoded_tensors[self.ENCODED_VALUES_KEY]
        tensor = tf.fill([800], 0.0)
        decoded_values = tf.tensor_scatter_nd_update(tensor, indices, values)

        return tf.reshape(decoded_values, shape)


def sparse_quantizing_encoder():
    encoder = te.core.EncoderComposer(
        StatefulTopKEncodingStage())
    return encoder.make()
fedavg_with_sparsification.py
[...]
def sparsification_broadcast_encoder_fn(value):
    spec = tf.TensorSpec(value.shape, value.dtype)
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def sparsification_mean_encoder_fn(value):
    spec = tf.TensorSpec(value.shape, value.dtype)

    if value.shape.num_elements() == 800:
        return te.encoders.as_gather_encoder(
            stateful_encoding_stage_topk.sparse_quantizing_encoder(), spec)
    else:
        return te.encoders.as_gather_encoder(te.encoders.identity(), spec)


encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        model_fn, sparsification_broadcast_encoder_fn))

encoded_mean_process = (
    tff.learning.framework.build_encoded_mean_process_from_model(
        model_fn, sparsification_mean_encoder_fn))

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.004),
    client_weight_fn=lambda _: tf.constant(1.0),
    broadcast_process=encoded_broadcast_process,
    aggregation_process=encoded_mean_process)
[...]
I am using:
I'll try to answer in two parts: (1) the top_k encoder without state, and (2) realizing the stateful idea you seem to want in TFF.
(1)
To get the TopKEncodingStage working without state, I see a few details to change.
The commutes_with_sum property should be set to False. In pseudo-code, its meaning is whether sum_x(decode(encode(x))) == decode(sum_x(encode(x))). This is not true for the representation your encode method returns: summing the indices would not work well. I think the implementation of the decode method can be simplified to:
return tf.scatter_nd(
    indices=encoded_tensors[self.INDICES_KEY],
    updates=encoded_tensors[self.ENCODED_VALUES_KEY],
    shape=shape)
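To make the commutes_with_sum point concrete, here is a small library-free sketch (plain Python lists rather than TensorFlow; the helper names encode and decode and the toy vectors are made up for illustration and are not the tensor_encoding API). It compares decoding each client's message and then summing against summing the encoded values and indices elementwise and then decoding:

```python
def encode(x, k):
    """Keep the k largest-magnitude entries of x as (values, indices)."""
    order = sorted(range(len(x)), key=lambda i: abs(x[i]), reverse=True)
    indices = order[:k]
    return [x[i] for i in indices], indices


def decode(values, indices, size):
    """Scatter the values back into a dense vector of zeros."""
    dense = [0.0] * size
    for value, index in zip(values, indices):
        dense[index] += value
    return dense


# Two hypothetical client updates of length 4, each keeping k = 2 entries.
client_a = [5.0, 1.0, 0.0, 0.0]
client_b = [0.0, 3.0, 2.0, 0.0]
values_a, indices_a = encode(client_a, k=2)
values_b, indices_b = encode(client_b, k=2)

# Option 1: decode every message first, then sum the dense vectors.
sum_of_decoded = [
    p + q
    for p, q in zip(decode(values_a, indices_a, 4), decode(values_b, indices_b, 4))
]

# Option 2: sum the encoded tensors elementwise (indices included), then decode.
summed_values = [p + q for p, q in zip(values_a, values_b)]
summed_indices = [p + q for p, q in zip(indices_a, indices_b)]
decoded_sum = decode(summed_values, summed_indices, 4)

# The two results differ, so this encoding does not commute with sum.
print(sum_of_decoded)
print(decoded_sum)
```

With larger indices, the summed indices could even fall outside the vector, which is another way to see that the property does not hold for this representation.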
(2)
What you describe cannot be achieved in this manner using tff.learning.build_federated_averaging_process. The process returned by this method does not have any mechanism for maintaining client/local state. Whatever state is expressed in your StatefulTopKEncodingStage would end up being server state, not local state.
To work with client/local state, you may need to write more custom code. As a starting point, see examples/stateful_clients, which you can adapt to store the state you refer to.
Keep in mind that in TFF, this will need to be represented as functional transformations. Storing values in attributes of a class and using them elsewhere can lead to surprising errors.
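As a rough illustration of what "functional transformation" means here, the sketch below passes the residual in as an argument and returns the updated residual, instead of storing it on self. It is plain Python rather than TFF or tensor_encoding code; the function name sparsify_with_residual, the toy vectors, and the choice to keep signed values are all assumptions made for the example.

```python
def sparsify_with_residual(update, residual, k):
    """Pure function: returns (sparse_update, new_residual), mutating nothing."""
    # Add back the coordinates that were zeroed out in earlier rounds.
    corrected = [u + r for u, r in zip(update, residual)]
    # Indices of the k largest-magnitude coordinates.
    order = sorted(range(len(corrected)), key=lambda i: abs(corrected[i]), reverse=True)
    keep = set(order[:k])
    sparse = [c if i in keep else 0.0 for i, c in enumerate(corrected)]
    # Whatever was not sent becomes the state for the next round.
    new_residual = [c - s for c, s in zip(corrected, sparse)]
    return sparse, new_residual


# The caller owns the state and threads it through every round.
residual = [0.0, 0.0, 0.0, 0.0]
round_updates = [
    [4.0, 1.0, 0.0, 0.0],
    [0.0, 1.0, 3.0, 0.0],
]
sent = []
for update in round_updates:
    sparse, residual = sparsify_with_residual(update, residual, k=1)
    sent.append(sparse)

print(sent)
print(residual)
```

In a stateful-clients setup, the residual would be part of the per-client state that the client computation takes as input and returns as output, with the loop above replaced by the federated rounds.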