tensorflow-federated

Is there a way to cast a federated value?


If I have a federated value, say {int32}@CLIENTS, that I'd like to cast to {float32}@CLIENTS, is there an easy way to do this? Thanks!


Solution

  • Tensor manipulation generally needs to happen inside a function decorated with tff.tf_computation. Since the types mentioned have a placement (@CLIENTS), this is most likely inside a tff.federated_computation-decorated function, so the casting function needs to be applied to each client's value with tff.federated_map.

    Something like this:

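    import tensorflow as tf
    import tensorflow_federated as tff
    
    # Plain TF logic: casts a single (non-federated) tensor.
    @tff.tf_computation
    def cast_to_float(x):
      return tf.cast(x, tf.float32)
    
    # Takes one int32 per client; federated_map applies cast_to_float to each.
    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def my_func(a):
      a_float = tff.federated_map(cast_to_float, a)
      return a_float
    
    print(my_func.type_signature)
    # prints: ({int32}@CLIENTS -> {float32}@CLIENTS)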