Is there any way in TensorFlow Federated to make clients train the model for multiple epochs on their dataset? I found in the tutorials that one solution is to modify the dataset by running dataset.repeat(NUMBER_OF_EPOCHS), but why should I modify the dataset?
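The tf.data.Dataset API is the TF2 way of setting this up. It may be useful to think of the code as modifying the "data pipeline" rather than the "dataset" itself: transformations such as repeat return a new Dataset that describes how examples are streamed, and the underlying data is left unchanged.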
The tf.data guide (https://www.tensorflow.org/guide/data), and particularly its section on processing multiple epochs (https://www.tensorflow.org/guide/data#processing_multiple_epochs), are useful pointers.
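To make the "pipeline, not dataset" distinction concrete, here is a minimal sketch (the names client_data, three_epochs, original_values and pipeline_values are hypothetical). It shows that repeat returns a new Dataset and that iterating the original still yields the original three examples:

```python
import tensorflow as tf

# Hypothetical client data: three examples.
client_data = tf.data.Dataset.range(3)

# `repeat` builds a new Dataset describing the pipeline; it does not
# copy or mutate `client_data`.
three_epochs = client_data.repeat(3)

original_values = [int(x.numpy()) for x in client_data]
pipeline_values = [int(x.numpy()) for x in three_epochs]

print(original_values)  # [0, 1, 2]
print(pipeline_values)  # [0, 1, 2, 0, 1, 2, 0, 1, 2]
```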
At a high level, the tf.data API sets up a stream of examples. Repeats of that stream (i.e., multiple epochs) can be configured as well.
import tensorflow as tf

dataset = tf.data.Dataset.range(5)
for x in dataset:
    print(x.numpy())  # prints 0, 1, 2, 3, 4 on separate lines.

repeated_dataset = dataset.repeat(2)
for x in repeated_dataset:
    print(x.numpy())  # same as above, but twice

# Shuffle before repeating so that each epoch gets its own ordering.
shuffled_repeat_dataset = dataset.shuffle(
    buffer_size=5, reshuffle_each_iteration=True).repeat(2)
for x in shuffled_repeat_dataset:
    print(x.numpy())  # same as above, but twice, with different orderings.
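In a federated setting, the same idea is usually applied per client: each client's dataset goes through a preprocessing function before training, and the client's local update iterates once over whatever dataset it receives. The sketch below assumes exactly that. preprocess, local_pass, NUM_EPOCHS, BATCH_SIZE and SHUFFLE_BUFFER are hypothetical names, and local_pass only counts the examples it sees instead of training a model, so this is an illustration of the idea rather than TFF's actual client update:

```python
import tensorflow as tf

NUM_EPOCHS = 3       # hypothetical: local epochs per round
BATCH_SIZE = 2       # hypothetical batch size
SHUFFLE_BUFFER = 10  # hypothetical shuffle buffer size

def preprocess(dataset):
    """Hypothetical per-client pipeline: repeat, shuffle, then batch."""
    return (dataset
            .repeat(NUM_EPOCHS)
            .shuffle(SHUFFLE_BUFFER, seed=1)
            .batch(BATCH_SIZE))

def local_pass(batched_dataset):
    """Stand-in for a client update that iterates over its dataset exactly
    once. No training happens here; it only counts the examples seen."""
    examples_seen = 0
    for batch in batched_dataset:
        examples_seen += int(tf.shape(batch)[0])
    return examples_seen

client_dataset = tf.data.Dataset.range(5)  # hypothetical client with 5 examples
one_epoch = local_pass(client_dataset.batch(BATCH_SIZE))
multi_epoch = local_pass(preprocess(client_dataset))
print(one_epoch)    # 5  -> a single pass is one epoch
print(multi_epoch)  # 15 -> a single pass over the repeated pipeline is NUM_EPOCHS epochs
```

Putting repeat before shuffle, as here, lets the shuffle buffer mix examples across epoch boundaries; putting shuffle before repeat, as in the example above, finishes one epoch before starting the next. The multiple-epochs section of the tf.data guide discusses this trade-off.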