tensorflow, tensorflow-federated

`tensorflow_federated.learning.from_keras_model()` no longer accepts the 'dummy_batch' keyword?


I ran the TensorFlow Federated tutorial notebook at https://colab.research.google.com/github/tensorflow/federated/blob/v0.13.1/docs/tutorials/federated_learning_for_image_classification.ipynb and got this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-17-d5336a451ad0> in <module>()
      2     model_fn,
      3     client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
----> 4     server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

2 frames
<ipython-input-16-7b97120f96c2> in model_fn()
      7       dummy_batch=sample_batch,
      8       loss=tf.keras.losses.SparseCategoricalCrossentropy(),
----> 9       metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

TypeError: from_keras_model() got an unexpected keyword argument 'dummy_batch'

The notebook upgrades tensorflow_federated to the latest release, so the installed TFF version is 0.14.0. Does this mean that in 0.14.0 we no longer need to pass a dummy batch? Has the usual TFF workflow changed?

P.S. Downgrading tensorflow_federated to version 0.13.1 works.


Solution

  • This is correct; the dummy_batch keyword was deprecated in this commit in favor of input_spec, for extra flexibility.

    There are several ways one might get their hands on an input_spec, including computing it directly from the arrays or tensors to be fed into the model, but the simplest is to access the element_spec attribute of an associated tf.data.Dataset which the model will train on.

    As for the link to the Colab itself, it looks like when TFF updated its links as part of today's release, it forgot to include the v when tagging the commit on GitHub. The links are now updated, and this should take you to a version of the Colab that works with 0.14.0.
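The spec can also be written by hand from the shapes and dtypes of the arrays that will be fed to the model. The sketch below uses only TensorFlow and hypothetical toy arrays, and checks that a hand-built `collections.OrderedDict` of `tf.TensorSpec`s matches what `element_spec` reports for the batched dataset. The leading `None` stands for the batch dimension, which `batch()` leaves unknown because the last batch may be smaller.

```python
import collections

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 20 flattened 28x28 images and their integer labels.
x = np.random.rand(20, 784).astype(np.float32)
y = np.random.randint(0, 10, size=(20, 1)).astype(np.int32)

dataset = tf.data.Dataset.from_tensor_slices(
    collections.OrderedDict(x=x, y=y)).batch(10)

# Option 1: read the spec off the dataset the model will train on.
spec_from_dataset = dataset.element_spec

# Option 2: build the same spec by hand from the arrays' per-example shapes
# and dtypes, with None for the variable batch dimension.
spec_by_hand = collections.OrderedDict(
    x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
    y=tf.TensorSpec(shape=[None, 1], dtype=tf.int32))

# Both describe the same structure, shapes and dtypes.
for key in spec_by_hand:
  assert spec_by_hand[key].shape.as_list() == spec_from_dataset[key].shape.as_list()
  assert spec_by_hand[key].dtype == spec_from_dataset[key].dtype

print(spec_from_dataset)
```

Either object could then be passed as `input_spec`; the hand-built form is mainly useful when no `tf.data.Dataset` exists yet at model-construction time.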