How to save the model in the code below
If you want to run the code, please visit https://github.com/tensorflow/federated and download federated_learning_for_image_classification.ipynb.
I would appreciate it if you could tell me how to save the model trained by federated learning in the tutorial federated_learning_for_image_classification.ipynb.
from __future__ import absolute_import, division, print_function
import tensorflow_federated as tff
from matplotlib import pyplot as plt
import tensorflow as tf
import six
import numpy as np
from six.moves import range
import warnings
import collections
import nest_asyncio
import h5py_character  # custom helper module that loads the character dataset as federated data
from tensorflow.keras import layers
nest_asyncio.apply()
warnings.simplefilter('ignore')
tf.compat.v1.enable_v2_behavior()
np.random.seed(0)
NUM_CLIENTS = 1
NUM_EPOCHS = 1
BATCH_SIZE = 20
SHUFFLE_BUFFER = 500
num_classes = 3755
if six.PY3:
  tff.framework.set_default_executor(
      tff.framework.create_local_executor(NUM_CLIENTS))

data_train = h5py_character.load_characters_data()
print(len(data_train.client_ids))

example_dataset = data_train.create_tf_dataset_for_client(
    data_train.client_ids[0])

def preprocess(dataset):
  def element_fn(element):
    # element['data'] = tf.expand_dims(element['data'], axis=-1)
    return collections.OrderedDict([
        # ('x', tf.reshape(element['data'], [-1])),
        ('x', tf.reshape(element['data'], [64, 64, 1])),
        ('y', tf.reshape(element['label'], [1])),
    ])
  return dataset.repeat(NUM_EPOCHS).map(element_fn).shuffle(
      SHUFFLE_BUFFER).batch(BATCH_SIZE)

preprocessed_example_dataset = preprocess(example_dataset)
print(next(iter(preprocessed_example_dataset)))

sample_batch = tf.nest.map_structure(
    lambda x: x.numpy(), next(iter(preprocessed_example_dataset)))

def make_federated_data(client_data, client_ids):
  return [preprocess(client_data.create_tf_dataset_for_client(x))
          for x in client_ids]

sample_clients = data_train.client_ids[0:NUM_CLIENTS]
federated_train_data = make_federated_data(data_train, sample_clients)

def create_compiled_keras_model():
  model = tf.keras.Sequential([
      layers.Conv2D(input_shape=(64, 64, 1), filters=64, kernel_size=(3, 3),
                    strides=(1, 1), padding='same', activation='relu'),
      layers.MaxPool2D(pool_size=(2, 2), padding='same'),
      layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
      layers.MaxPool2D(pool_size=(2, 2), padding='same'),
      layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
      layers.MaxPool2D(pool_size=(2, 2), padding='same'),
      layers.Flatten(),
      layers.Dense(1024, activation='relu'),
      layers.Dense(num_classes, activation='softmax')
  ])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      # metrics=['accuracy'])
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model

def model_fn():
  keras_model = create_compiled_keras_model()
  # Keep a handle on the Keras model so its weights can be saved later.
  global model_to_save
  model_to_save = keras_model
  keras_model.summary()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()

state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics={}'.format(metrics))

for round_num in range(2, 110):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
Roughly, you will be using the FileCheckpointManager object and its save_checkpoint/load_checkpoint methods. In particular, you can instantiate a FileCheckpointManager and ask it to save state (almost) directly.

The state in your example is an instance of tff.python.common_libs.anonymous_tuple.AnonymousTuple (IIRC), which is not compatible with tf.convert_to_tensor, as is needed by save_checkpoint and declared in its docstring. The general solution often used in TFF research code is to introduce a Python attrs class to convert away from the anonymous tuple as soon as state is returned; a minimal sketch of such a class follows.
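For concreteness, here is one way such an attrs class might look. This is a sketch, not TFF API: it assumes the FedAvg server state in your TFF version carries model and optimizer_state fields (inspect your state to confirm), and the ServerState name and from_anon_tuple helper are illustrative.

import attr

@attr.s(frozen=True)
class ServerState(object):
  # Mirrors the assumed fields of the FedAvg server state.
  model = attr.ib()
  optimizer_state = attr.ib()

  @classmethod
  def from_anon_tuple(cls, anon_state):
    # AnonymousTuple supports attribute access by field name. Note that
    # nested anonymous tuples (e.g. inside `model`) may need the same
    # conversion before tf.convert_to_tensor will accept them.
    return cls(model=anon_state.model,
               optimizer_state=anon_state.optimizer_state)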
Assuming the above, the following sketch should work:
# state assumed an anonymous tuple, previously created
# N some integer
ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)
And to restore from this checkpoint, at any time you can call:
state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))
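Separately, if what you ultimately want is a standard Keras artifact rather than a TFF checkpoint, you can push the trained weights held in state back into a Keras model and use the usual Keras saving machinery. A minimal sketch, assuming the TFF version from this tutorial (which exposed tff.learning.assign_weights_to_keras_model); the output path is illustrative:

keras_model = create_compiled_keras_model()
# Copy the trained federated weights from the server state into the
# Keras model, then save with the standard Keras APIs.
tff.learning.assign_weights_to_keras_model(keras_model, state.model)
keras_model.save_weights('/tmp/federated_character_model')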
One thing to note: the code pointers referenced above are generally under tff.python.research..., which is not included in the pip package; so the preferred way to get at them is to either fork the code into your own project, or pull down the repo and build it from source.
Thanks for your interest in TFF!