tensorflow, machine-learning, tensorflow2.0, tensorflow-federated, federated-learning

MSE error different during training and evaluation in tensorflow federated


I am implementing a regression model in TensorFlow Federated. I started from the simple model used in this Keras regression tutorial: https://www.tensorflow.org/tutorials/keras/regression

I changed the model to use federated learning. Here is my code:

import pandas as pd
import tensorflow as tf

from sklearn.preprocessing import StandardScaler
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_federated as tff

dataset_path = keras.utils.get_file("auto-mpg.data", "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data")

column_names = ['MPG','Cylinders','Displacement','Horsepower','Weight',
                'Acceleration', 'Model Year', 'Origin']
raw_dataset = pd.read_csv(dataset_path, names=column_names,
                      na_values = "?", comment='\t',
                      sep=" ", skipinitialspace=True)

df = raw_dataset.copy()
df = df.dropna()
dfs = [x for _, x in df.groupby('Origin')]  # one dataframe per Origin; these act as the clients


datasets = []
targets = []
for dataframe in dfs:
    target = dataframe.pop('MPG')

    standard_scaler_x = StandardScaler(with_mean=True, with_std=True)
    normalized_values = standard_scaler_x.fit_transform(dataframe.values)

    dataset = tf.data.Dataset.from_tensor_slices(({ 'x': normalized_values, 'y': target.values}))
    train_dataset = dataset.shuffle(len(dataframe)).repeat(10).batch(20)
    test_dataset = dataset.shuffle(len(dataframe)).batch(1)
    datasets.append(train_dataset)


def build_model():
  model = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=[7]),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
  ])
  return model


import collections


model = build_model()

sample_batch = tf.nest.map_structure(
    lambda x: x.numpy(), next(iter(datasets[0])))

def loss_fn_Federated(y_true, y_pred):
    return tf.reduce_mean(tf.keras.losses.MSE(y_true, y_pred))

def create_tff_model():
  keras_model_clone = tf.keras.models.clone_model(model)
#   adam = keras.optimizers.Adam()
  adam = tf.keras.optimizers.SGD(0.002)
  keras_model_clone.compile(optimizer=adam, loss='mse', metrics=[tf.keras.metrics.MeanSquaredError()])
  return tff.learning.from_compiled_keras_model(keras_model_clone, sample_batch)

print("Create averaging process")
# This command builds all the TensorFlow graphs and serializes them: 
iterative_process = tff.learning.build_federated_averaging_process(model_fn=create_tff_model)

print("Initzialize averaging process")
state = iterative_process.initialize()

print("Start iterations")
for _ in range(10):
  state, metrics = iterative_process.next(state, datasets)
  print('metrics={}'.format(metrics))
Start iterations
metrics=<mean_squared_error=95.8644027709961,loss=96.28633880615234>
metrics=<mean_squared_error=9.511247634887695,loss=9.522096633911133>
metrics=<mean_squared_error=8.26853084564209,loss=8.277074813842773>
metrics=<mean_squared_error=7.975323677062988,loss=7.9771647453308105>
metrics=<mean_squared_error=7.618809700012207,loss=7.644164562225342>
metrics=<mean_squared_error=7.347906112670898,loss=7.340310096740723>
metrics=<mean_squared_error=7.210267543792725,loss=7.210223197937012>
metrics=<mean_squared_error=7.045553207397461,loss=7.045469760894775>
metrics=<mean_squared_error=6.861278533935547,loss=6.878870487213135>
metrics=<mean_squared_error=6.80275297164917,loss=6.817670822143555>
evaluation = tff.learning.build_federated_evaluation(model_fn=create_tff_model)


test_metrics = evaluation(state.model, datasets)
print(test_metrics)
<mean_squared_error=27.308320999145508,loss=27.19877052307129>

I am confused why the MSE from evaluation on the training data after 10 iterations is so much higher than the MSE reported by the iterative process during training. What am I doing wrong here? Is it something hidden in the implementation of federated learning in TensorFlow Federated? Can someone explain it to me?


Solution

  • You've actually hit on a very interesting phenomenon in federated learning. In particular, the question that needs asking here is: how are the training metrics computed?

    Training metrics are generally computed during local training; that is, they are computed as the client is fitting its local data. In TFF, they are computed before each local step is taken, during the forward pass call. If you imagine the extreme situation where metrics were only computed at the end of a round of training on each client, you would see one thing clearly--the client is reporting metrics that represent how well it has fit its own local data.
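    A loose single-machine analogy (plain Keras, nothing federated; the toy data below is made up purely for illustration) shows the same timing effect: the loss that fit reports is averaged over batches while the weights are still moving, so it differs from evaluate run with the final weights on the same data.

    import numpy as np
    import tensorflow as tf

    # Hypothetical toy data, shaped like the 7-feature regression problem above.
    w_true = np.random.normal(size=(7, 1)).astype('float32')
    x = np.random.normal(size=(256, 7)).astype('float32')
    y = x @ w_true + 0.1 * np.random.normal(size=(256, 1)).astype('float32')

    m = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=[7]),
        tf.keras.layers.Dense(1)])
    m.compile(optimizer='sgd', loss='mse')

    # Loss reported by fit: averaged over batches *during* the epoch.
    history = m.fit(x, y, epochs=1, batch_size=20, verbose=0)
    print('loss during fit:', history.history['loss'][0])

    # Loss from evaluate: computed with the final weights only.
    print('loss after fit :', m.evaluate(x, y, verbose=0))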

    Federated learning, however, has to produce a single, global model at the end of each round of training--in federated averaging, the local models are averaged together in parameter space. In the general case it is unclear how to interpret such a step intuitively: an average of nonlinear models in parameter space does not give you an average of their predictions, or anything like that.
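    As a minimal sketch of that point (the tiny architecture here is arbitrary, chosen only for illustration), averaging the weights of two randomly initialized ReLU networks yields a model whose prediction is generally not the mean of their predictions:

    import numpy as np
    import tensorflow as tf

    def tiny_model():
      return tf.keras.Sequential([
          tf.keras.layers.Dense(4, activation='relu', input_shape=[2]),
          tf.keras.layers.Dense(1)])

    a, b, avg = tiny_model(), tiny_model(), tiny_model()

    # Average the two models in parameter space, as federated averaging does
    # (ignoring client weighting here).
    avg.set_weights([(wa + wb) / 2.0
                     for wa, wb in zip(a.get_weights(), b.get_weights())])

    x = np.array([[1.0, -1.0]], dtype='float32')
    print('mean of predictions:', (a(x).numpy() + b(x).numpy()) / 2.0)
    print('prediction of mean :', avg(x).numpy())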

    Federated evaluation takes this averaged model and runs local evaluation on each client, without fitting the local data at all. Therefore, if your client datasets have quite different distributions, you should expect the metrics returned from federated evaluation to be quite different from those returned by a round of federated training: federated averaging reports metrics gathered during the process of adapting to the local data, while federated evaluation reports metrics gathered after averaging all of these locally trained models together.

    Indeed, if you interleave calls to the next function of your iterative process with calls to your evaluation function, you can watch this happen round by round.
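    A minimal sketch of such an interleaving loop, reusing the iterative_process, evaluation and datasets objects defined above (and assuming the evaluation computation has been built before the loop starts):

    state = iterative_process.initialize()
    for _ in range(10):
      state, train_metrics = iterative_process.next(state, datasets)
      print('train metrics={}'.format(train_metrics))
      eval_metrics = evaluation(state.model, datasets)
      print('eval metrics={}'.format(eval_metrics))

    With this loop, the printed metrics look something like the following: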

    train metrics=<mean_squared_error=88.22489929199219,loss=88.6319351196289>
    eval metrics=<mean_squared_error=33.69473648071289,loss=33.55160140991211>
    train metrics=<mean_squared_error=8.873666763305664,loss=8.882776260375977>
    eval metrics=<mean_squared_error=29.235883712768555,loss=29.13833236694336>
    train metrics=<mean_squared_error=7.932246208190918,loss=7.918393611907959>
    eval metrics=<mean_squared_error=27.9038028717041,loss=27.866817474365234>
    train metrics=<mean_squared_error=7.573018550872803,loss=7.576478958129883>
    eval metrics=<mean_squared_error=27.600923538208008,loss=27.561887741088867>
    train metrics=<mean_squared_error=7.228050708770752,loss=7.224897861480713>
    eval metrics=<mean_squared_error=27.46322250366211,loss=27.36537742614746>
    train metrics=<mean_squared_error=7.049572944641113,loss=7.03688907623291>
    eval metrics=<mean_squared_error=26.755760192871094,loss=26.719152450561523>
    train metrics=<mean_squared_error=6.983217716217041,loss=6.954374313354492>
    eval metrics=<mean_squared_error=26.756895065307617,loss=26.647253036499023>
    train metrics=<mean_squared_error=6.909178256988525,loss=6.923810005187988>
    eval metrics=<mean_squared_error=27.047882080078125,loss=26.86684799194336>
    train metrics=<mean_squared_error=6.8190460205078125,loss=6.79202938079834>
    eval metrics=<mean_squared_error=26.209386825561523,loss=26.10053062438965>
    train metrics=<mean_squared_error=6.7200140953063965,loss=6.737307071685791>
    eval metrics=<mean_squared_error=26.682661056518555,loss=26.64984703063965>
    

    That is, your federated evaluation metrics are also going down, just much more slowly than your training metrics--effectively, they are measuring the variation across your client datasets. You can validate this by running:

    eval_metrics = evaluation(state.model, [datasets[0]])
    print('eval metrics on 0th dataset={}'.format(eval_metrics))
    eval_metrics = evaluation(state.model, [datasets[1]])
    print('eval metrics on 1st dataset={}'.format(eval_metrics))
    eval_metrics = evaluation(state.model, [datasets[2]])
    print('eval metrics on 2nd dataset={}'.format(eval_metrics))
    

    and you will see a result like

    eval metrics on 0th dataset=<mean_squared_error=9.426984786987305,loss=9.431192398071289>
    eval metrics on 1st dataset=<mean_squared_error=34.96992111206055,loss=34.96992492675781>
    eval metrics on 2nd dataset=<mean_squared_error=72.94075775146484,loss=72.88787841796875>
    

    so you can see that your averaged model has dramatically different performance across these three datasets.

    One final note: you may notice that the final result from your evaluation function is not the simple average of your three per-client losses--this is because federated evaluation is example-weighted, not client-weighted; that is, clients with more data get more weight in the average.
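    A small arithmetic sketch of that weighting, using the per-client losses printed just above and purely hypothetical client sizes (the example counts here are made up; the real groups in your split will have their own sizes):

    losses = [9.43, 34.97, 72.94]      # per-client eval losses from above
    num_examples = [250, 70, 80]       # hypothetical per-client dataset sizes

    unweighted = sum(losses) / len(losses)
    example_weighted = (sum(l * n for l, n in zip(losses, num_examples))
                        / sum(num_examples))

    print('unweighted (per-client) mean:', unweighted)        # ~39.1
    print('example-weighted mean       :', example_weighted)  # ~26.6 with these counts

    With these made-up counts, the example-weighted mean is pulled toward the largest client, which here also happens to be the best-fit one.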

    Hope this helps!