Tags: tensorflow, python-3.6, tensorflow-federated

TensorFlow Federated (TFF) TypeError in tff.templates.IterativeProcess.next() when clients_per_round exceeds 99


I implemented a custom federated learning GAN training loop with TFF similar to this code by Google Research.

The client data for a particular training round is selected with the following code snippet:

import numpy as np

def client_dataset_fn():
    # Sample clients_per_round client IDs without replacement and pair each
    # client's generator-input batch with its (truncated) real-data dataset.
    sampled_clients = np.random.choice(
        train_data.client_ids, size=cfg.clients_per_round, replace=False)
    datasets = [(next(client_gen_inputs_iterator),
                 train_data.create_tf_dataset_for_client(client_id).take(cfg.n_critic))
                for client_id in sampled_clients]
    return datasets

client_noise_inputs, client_real_data = zip(*client_dataset_fn())

This works perfectly as long as cfg.clients_per_round is at most 99. As soon as it is set to 100 or more (with the total number of clients being larger, of course), I receive the following error:

Traceback (most recent call last):
  File "main.py", line 109, in main
    metrics = run_single_trial(train_data, test_data, cfg)
  File "/mnt/workspace/tff/GAN/federated/fedgan_main.py", line 73, in run_single_trial
    metrics = train_loop(iterative_process, server_dataset_fn, client_dataset_fn, model, eval_hook_fn, cfg)
  File "/mnt/workspace/tff/GAN/federated/fedgan_main.py", line 124, in train_loop
    client_real_data)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/computation/function_utils.py", line 525, in __call__
    return context.invoke(self, arg)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 49, in wrapped_f
    return Retrying(*dargs, **dkw).call(f, *args, **kw)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 206, in call
    return attempt.get(self._wrap_exception)
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 247, in get
    six.reraise(self.value[0], self.value[1], self.value[2])
  File "/usr/local/lib/python3.6/dist-packages/six.py", line 703, in reraise
    raise value
  File "/usr/local/lib/python3.6/dist-packages/retrying.py", line 200, in call
    attempt = Attempt(fn(*args, **kwargs), attempt_number, False)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/execution_context.py", line 226, in invoke
    _ingest(executor, unwrapped_arg, arg.type_signature)))
  File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
    return future.result()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 396, in _wrapped
    return await coro
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/execution_context.py", line 111, in _ingest
    ingested = await asyncio.gather(*ingested)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/execution_context.py", line 116, in _ingest
    return await executor.create_value(val, type_spec)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 201, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/reference_resolving_executor.py", line 294, in create_value
    value, type_spec))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 201, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 111, in create_value
    self._target_executor.create_value(value, type_spec))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/thread_delegating_executor.py", line 105, in _delegate
    result_value = await _delegate_with_trace_ctx(coro, self._event_loop)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 396, in _wrapped
    return await coro
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/tracing.py", line 201, in async_trace
    result = await fn(*fn_args, **fn_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/federating_executor.py", line 394, in create_value
    return await self._strategy.compute_federated_value(value, type_spec)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/executors/federated_composing_strategy.py", line 279, in compute_federated_value
    py_typecheck.check_type(value, list)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/common_libs/py_typecheck.py", line 41, in check_type
    type_string(type_spec), type_string(type(target))))
TypeError: Expected list, found tuple.

During debugging, I inspected the target variable in the final line of the traceback and found that it is the above-mentioned client_real_data and client_noise_inputs. Their types are indeed tuples, not lists; however, this does not change with different values of cfg.clients_per_round. The only place cfg.clients_per_round is used is in the random choice shown above. I cannot explain why this is happening; perhaps somebody out there has experienced something similar and can help me out.
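For illustration, here is a minimal plain-Python sketch (not TFF-specific, with stand-in names invented for the example) of what the failing check boils down to: zip(*...) always yields tuples regardless of how many clients are sampled, and the check py_typecheck.check_type(value, list) in the traceback only accepts an actual list.

# Stand-ins for the (generator_input, client_dataset) pairs built above.
pairs = [("noise_0", "data_0"), ("noise_1", "data_1")]
client_noise_inputs, client_real_data = zip(*pairs)

print(type(client_real_data))               # <class 'tuple'> -- zip always yields tuples
print(isinstance(client_real_data, list))   # False, hence "Expected list, found tuple."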

The package versions I am using are as follows:

  • Python 3.6.9 or 3.8.10 (checked both)
  • tensorflow 2.5.1
  • tensorflow-federated 0.19.0
  • retrying 1.3.3
  • six 1.15.0

As a workaround, I now manually convert client_noise_inputs and client_real_data to lists with list(tuple_var), but I am still curious why a list is required here.
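For completeness, a short sketch of that workaround, using only the names from the snippet above:

# Convert the tuples produced by zip(*) into lists before they reach TFF.
client_noise_inputs, client_real_data = zip(*client_dataset_fn())
client_noise_inputs = list(client_noise_inputs)
client_real_data = list(client_real_data)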


Solution

  • (Copied and pasted from the original answer on GitHub)

    This seems to me to be an implementation distinction between the federated_composing_strategy and the federated_resolving_strategy. IIRC, by default we don't inject a composing executor into your stack until you hit 100 clients--which would be the source of this exciting mystery.

    In particular, the composing strategy is programmed against the assumption that the incoming clients-placed value is represented as a list, whereas the resolving strategy codes against a much more flexible set of containers.

    It's not wild to coerce your clients-placed value to a list--we also could extend the permitted representation of clients-placed values in the composing executor to match that in the resolving one, possibly pulling the appropriate logic to a shared place like here. I think it's a contribution we'd be very happy to accept if you're up for it!
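To make that distinction concrete, here is an illustrative sketch (an assumption about the behaviour described above, not actual TFF source) of the two kinds of checks: the composing strategy insists on an actual list, while the resolving strategy accepts a wider range of containers.

def composing_style_check(value):
    # Strict: only a real list passes, mirroring py_typecheck.check_type(value, list).
    if not isinstance(value, list):
        raise TypeError('Expected list, found {}.'.format(type(value).__name__))

def resolving_style_check(value):
    # Permissive: any of several sequence containers is accepted.
    if not isinstance(value, (list, tuple)):
        raise TypeError('Expected a sequence, found {}.'.format(type(value).__name__))

clients_placed_value = tuple(range(3))
resolving_style_check(clients_placed_value)   # accepted
composing_style_check(clients_placed_value)   # raises TypeError: Expected list, found tuple.

Since the composing executor is only injected once the client count reaches 100, the stricter check never runs for smaller rounds, which is why the same tuples work fine below that threshold.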