TensorFlow Federated: How to tune non-IIDness in federated dataset?


I am testing some algorithms in TensorFlow Federated (TFF). In this regard, I would like to test and compare them on the same federated dataset with different "levels" of data heterogeneity, i.e. non-IIDness.

Hence, I would like to know whether there is any way to control and tune the "level" of non-IIDness of a specific federated dataset, in an automatic or semi-automatic fashion, e.g. by means of TFF APIs or plain TF APIs (maybe the tf.data utilities).

To be more concrete: the EMNIST federated dataset provided by TFF has 3383 clients, each holding their own handwritten characters. However, these local datasets seem to be quite balanced, both in the number of local examples and in the represented classes (all classes are, more or less, represented locally). If I want a federated dataset (e.g., starting from TFF's EMNIST one) that is:

  • Pathologically non-IID, for example with clients that hold only one class out of N classes (always in the context of a classification task). Is this the purpose of tff.simulation.datasets.build_single_label_dataset? If so, how should I use it with a federated dataset such as the ones already provided by TFF?;
  • Unbalanced in terms of the amount of local examples (e.g., one client has 10 examples, another one has 100 examples);
  • Both of the above;

how should I proceed inside the TFF framework to prepare a federated dataset with those characteristics?

Should I do all of this by hand, or do any of you have advice on automating the process?

An additional question: in the paper "Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification" by Hsu et al., the authors use the Dirichlet distribution to synthesize a population of non-identical clients, with a concentration parameter that controls how identical the clients are. This seems an easy-to-tune way to produce datasets with different levels of heterogeneity. Any advice on how to implement this strategy (or a similar one) in the TFF framework, or just in TensorFlow (Python), for a simple dataset such as EMNIST, would be very useful too.
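A minimal NumPy sketch of a Dirichlet-based label split, under stated assumptions. It implements a common per-class variant (each class's examples are divided among clients in proportions drawn from Dir(alpha * 1)), which is not necessarily the exact sampling procedure of Hsu et al., who draw a per-client class distribution q ~ Dir(alpha * p) around a prior class distribution p. The function name `dirichlet_label_partition` is hypothetical, and `labels` is assumed to be a 1-D integer array of labels pooled from all clients (for EMNIST, e.g. collected by iterating `create_tf_dataset_from_all_clients()`).

```python
import numpy as np


def dirichlet_label_partition(labels, num_clients, alpha, seed=0):
    """Hypothetical helper: split example indices across clients by label.

    For every class, a proportion vector drawn from Dir(alpha * ones) decides
    how that class's examples are spread over the clients. Small alpha gives
    strongly skewed (non-IID) clients; large alpha approaches an IID split.
    Every example is assigned to exactly one client.
    """
    if alpha <= 0:
        raise ValueError("alpha must be positive")
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        rng.shuffle(idx)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn the proportions into cut points over this class's examples.
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return [np.array(sorted(ix), dtype=int) for ix in client_indices]


if __name__ == "__main__":
    labels = np.repeat(np.arange(10), 100)  # toy stand-in for EMNIST labels
    for alpha in (0.1, 100.0):
        parts = dirichlet_label_partition(labels, num_clients=5, alpha=alpha)
        # Number of distinct labels held by each client.
        print(alpha, [len(np.unique(labels[p])) for p in parts])
```

Smaller `alpha` tends to leave each client with fewer distinct labels. The returned index arrays can slice pooled feature and label arrays, which `tf.data.Dataset.from_tensor_slices` can then wrap into one dataset per client.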

Thanks a lot.


Solution

  • For Federated Learning simulations, it's quite reasonable to set up the client datasets in Python, in the experiment driver, to achieve the desired distributions. At a high level, TFF handles modeling data location ("placements" in the type system) and computation logic. Re-mixing or generating a simulation dataset is not really core to the library, though there are helpful utilities, as you've found. Doing this directly in Python by manipulating tf.data.Dataset objects and then "pushing" the client datasets into a TFF computation seems straightforward.
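A rough sketch of that driver-side step, assuming the examples have already been pooled into NumPy arrays `pixels` and `labels` and split into a list of per-client index arrays by some partitioning rule. The helper `materialize_clients` and the `client_{i}` IDs are hypothetical; the dict keys mirror EMNIST's `'pixels'` and `'label'` features.

```python
import numpy as np


def materialize_clients(pixels, labels, client_indices):
    """Hypothetical helper: slice pooled arrays into one dict per client.

    `client_indices` is a list of integer index arrays, one per client.
    """
    return {
        f"client_{i}": {"pixels": pixels[ix], "label": labels[ix]}
        for i, ix in enumerate(client_indices)
    }


if __name__ == "__main__":
    pixels = np.random.default_rng(0).random((6, 28, 28), dtype=np.float32)
    labels = np.array([0, 0, 1, 1, 2, 2], dtype=np.int32)
    clients = materialize_clients(
        pixels, labels, [np.array([0, 1]), np.array([2, 3, 4, 5])])
    for client_id, data in clients.items():
        print(client_id, data["pixels"].shape, data["label"].tolist())
```

With TensorFlow installed, each per-client dict could then be wrapped with `tf.data.Dataset.from_tensor_slices(data)`, and a Python list of those datasets passed as the federated data argument of a TFF computation, such as an iterative process's `next` function.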

    Label non-IID

    Yes, tff.simulation.datasets.build_single_label_dataset is intended for this purpose.

    It takes a tf.data.Dataset and essentially filters out all examples whose label_key value does not equal desired_label (assuming the dataset yields dict-like structures).

    For EMNIST, a dataset of all the examples labeled 1 (regardless of client) can be created with:

    import tensorflow_federated as tff

    train_data, _ = tff.simulation.datasets.emnist.load_data()
    # Keep only the examples labeled 1, pooled across all clients.
    ones = tff.simulation.datasets.build_single_label_dataset(
        train_data.create_tf_dataset_from_all_clients(),
        label_key='label', desired_label=1)
    print(ones.element_spec)
    # OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
    print(next(iter(ones))['label'])
    # tf.Tensor(1, shape=(), dtype=int32)
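`build_single_label_dataset` works on one `tf.data.Dataset` at a time, so one way to obtain pathologically non-IID clients is to build one pooled single-label dataset per label and then split each among several clients (for example with `tf.data.Dataset.shard`). The NumPy sketch below does the equivalent on a pooled label array, so every client holds exactly one class; the helper name, the `clients_per_label` parameter and the client-ID format are hypothetical.

```python
import numpy as np


def single_label_clients(labels, clients_per_label, seed=0):
    """Hypothetical helper: every client receives examples of one label only.

    Each label's example indices are shuffled and split into
    `clients_per_label` roughly equal chunks; each chunk becomes a client.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    clients = {}
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        rng.shuffle(idx)
        for k, chunk in enumerate(np.array_split(idx, clients_per_label)):
            clients[f"label{label}_client{k}"] = chunk
    return clients


if __name__ == "__main__":
    labels = np.repeat(np.arange(3), 10)  # toy stand-in for EMNIST labels
    for client_id, idx in single_label_clients(labels, 2).items():
        print(client_id, len(idx), np.unique(labels[idx]).tolist())
```

If one class per client is too extreme, giving each client chunks from two or more labels is a straightforward variation of the same loop.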
    

    Data imbalance

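    A combination of tf.data.Dataset.repeat and tf.data.Dataset.take can be used to create data imbalances: repeat oversamples a client by cycling through its examples again, while take truncates a client to its first N examples.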

    import tensorflow as tf
    import tensorflow_federated as tff

    train_data, _ = tff.simulation.datasets.emnist.load_data()
    # Take the datasets of the first two clients.
    datasets = [train_data.create_tf_dataset_for_client(client_id)
                for client_id in train_data.client_ids[:2]]
    print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
    # [93, 109]
    # Oversample client 0 five times; truncate client 1 to 5 examples.
    datasets[0] = datasets[0].repeat(5)
    datasets[1] = datasets[1].take(5)
    print([tf.data.experimental.cardinality(ds).numpy() for ds in datasets])
    # [465, 5]
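`repeat` and `take` work well for a handful of clients; for many clients it can be easier to draw the client sizes from a skewed distribution. A minimal NumPy sketch, assuming a log-normal size distribution (an illustrative choice, not something TFF prescribes); `quantity_skew_partition` and its `sigma` knob are hypothetical.

```python
import numpy as np


def quantity_skew_partition(num_examples, num_clients, sigma, seed=0):
    """Hypothetical helper: disjoint index sets whose sizes follow a log-normal.

    sigma=0 gives equal-sized clients; larger sigma makes sizes more unequal.
    Every client gets at least one example and no index is used twice.
    """
    if num_examples < num_clients:
        raise ValueError("need at least one example per client")
    rng = np.random.default_rng(seed)
    weights = rng.lognormal(mean=0.0, sigma=sigma, size=num_clients)
    # One guaranteed example per client; the rest is shared by weight.
    extra = np.floor(weights / weights.sum() * (num_examples - num_clients))
    sizes = 1 + extra.astype(int)
    # Shuffle all indices and cut them into consecutive chunks of those sizes.
    perm = rng.permutation(num_examples)
    cuts = np.cumsum(sizes)[:-1]
    return np.split(perm[: sizes.sum()], cuts)


if __name__ == "__main__":
    for sigma in (0.0, 1.5):
        parts = quantity_skew_partition(1000, num_clients=8, sigma=sigma)
        print(sigma, sorted(len(p) for p in parts))
```

The returned positions can index pooled arrays directly, or be applied to one label-skewed client's own index array to shrink it, which combines quantity skew with a label-based split.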