
I implemented a ResNet34 model in the Federated Learning for Image Classification tutorial. After 10 rounds the training accuracy exceeds 90%, but the evaluation accuracy computed with the last round's `state.model` is always around 50%.

    evaluation = tff.learning.build_federated_evaluation(model_fn)
    federated_test_data = make_federated_data(emnist_test, sample_clients)
    test_metrics = evaluation(state.model, federated_test_data)
    str(test_metrics)

What could be wrong with the evaluation step? I also printed the non-trainable variables (the BatchNorm moving mean and variance) of the server's model: they are still 0 and 1, with no updates or averaging after those rounds. Is that expected, or could it be the problem? Thanks very much!

Updates:

The code that prepares the training data, with the printed results:

    len(emnist_train.client_ids)
    4

    emnist_train.element_type_structure
    OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int64, name=None)), ('pixels', TensorSpec(shape=(256, 256, 3), dtype=tf.float32, name=None))])

    import collections

    import tensorflow as tf
    import tensorflow_federated as tff

    NUM_CLIENTS = 4
    NUM_EPOCHS = 1
    BATCH_SIZE = 30
    SHUFFLE_BUFFER = 500

    def preprocess(dataset):
      # Rename the fields to 'x'/'y' and give each label shape [1].
      def element_fn(element):
        return collections.OrderedDict([
            ('x', element['pixels']),
            ('y', tf.reshape(element['label'], [1])),
        ])
      return dataset.repeat(NUM_EPOCHS).map(element_fn).shuffle(
          SHUFFLE_BUFFER).batch(BATCH_SIZE)

    def make_federated_data(client_data, client_ids):
      # Build one preprocessed tf.data.Dataset per client.
      return [preprocess(client_data.create_tf_dataset_for_client(x))
              for x in client_ids]

    sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

    federated_train_data = make_federated_data(emnist_train, sample_clients)

    # example_dataset is a single client's dataset; its definition is not shown in this post.
    preprocessed_example_dataset = preprocess(example_dataset)

    sample_batch = tf.nest.map_structure(
        lambda x: x.numpy(), next(iter(preprocessed_example_dataset)))

    len(federated_train_data), federated_train_data[0]
    (4, <BatchDataset shapes: OrderedDict([(x, (None, 256, 256, 3)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int64)])>)

The training and evaluation code:

    def create_compiled_keras_model():
      # ImageNet-pretrained ResNet50 backbone with a 2-class softmax head.
      base_model = tf.keras.applications.resnet.ResNet50(
          include_top=False, weights='imagenet', input_shape=(256, 256, 3))
      global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
      prediction_layer = tf.keras.layers.Dense(2, activation='softmax')

      model = tf.keras.Sequential([
          base_model,
          global_average_layer,
          prediction_layer,
      ])
      model.compile(
          optimizer=tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
      return model

    def model_fn():
      keras_model = create_compiled_keras_model()
      return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

    iterative_process = tff.learning.build_federated_averaging_process(model_fn)
    state = iterative_process.initialize()
    for round_num in range(2, 12):
      state, metrics = iterative_process.next(state, federated_train_data)
      print('round {:2d}, metrics={}'.format(round_num, metrics))

    evaluation = tff.learning.build_federated_evaluation(model_fn)
    federated_test_data = make_federated_data(emnist_test, sample_clients)

    len(federated_test_data), federated_test_data[0]
    (4,
     <BatchDataset shapes: OrderedDict([(x, (None, 256, 256, 3)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int64)])>)

    test_metrics = evaluation(state.model, federated_test_data)
    str(test_metrics)

The training metrics after each round, each followed by the evaluation metrics for that round:

round  1, metrics=<sparse_categorical_accuracy=0.5089045763015747,loss=0.7813001871109009,keras_training_time_client_sum_sec=0.008826255798339844>

<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>


round  2, metrics=<sparse_categorical_accuracy=0.519825279712677,loss=0.7640910148620605,keras_training_time_client_sum_sec=0.011750459671020508>

<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>


round  3, metrics=<sparse_categorical_accuracy=0.5099126100540161,loss=0.7513422966003418,keras_training_time_client_sum_sec=0.0039823055267333984>

<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>


round  4, metrics=<sparse_categorical_accuracy=0.5278897881507874,loss=0.7905193567276001,keras_training_time_client_sum_sec=0.0010638236999511719>

<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>


round  5, metrics=<sparse_categorical_accuracy=0.5199933052062988,loss=0.7782396674156189,keras_training_time_client_sum_sec=0.012729644775390625>

<sparse_categorical_accuracy=0.49949443340301514,loss=8.0671968460083,keras_training_time_client_sum_sec=0.0>
Eduardo Yáñez Parareda
miaoz18

2 Answers


There are a few nuances and a few open research problems in Federated Learning and this question has struck a couple of them.

  1. Training loss looks much better than evaluation loss: when using Federated Averaging (the optimization algorithm used in the Federated Learning for Image Classification tutorial), one needs to be careful interpreting metrics, as they differ in subtle ways from those of centralized model training. This applies especially to the training loss, which is an average over many sequential steps or batches on each client. After one round, each client may have fit the model to its local data very well (obtaining a high accuracy), but once these updates are averaged into the global model, the global model may still be far from "good", resulting in a low test accuracy. Additionally, 10 rounds may be too few; one of the original academic papers on Federated Learning needed at least 20 rounds to reach 99% accuracy with IID data (McMahan 2016), and more than 100 rounds with non-IID data.

  2. BatchNorm in the federated setting: it's an open research problem how to combine the BatchNorm parameters, particularly with non-IID client data. Should each new client start with fresh parameters, or receive the global model parameters? TFF may not be communicating them between the server and the clients (since it is currently implemented to communicate only trainable variables), which may lead to unexpected behavior. It may be good to print the state parameters and watch what happens to them each round.
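
To make the first point concrete, here is a toy sketch in plain Python. It is not TFF code and not the model from the question: each client is assumed to have a one-parameter model with loss `(w - target)**2`, and the names `local_sgd` and `federated_round` are made up for this illustration. The training metric is averaged over the local steps, the way client training metrics are, while the evaluation metric is computed on the averaged global model.

```python
def local_sgd(w, target, steps=20, lr=0.25):
    """Run SGD on the loss (w - target)**2 starting from w.

    Returns the final weight and the mean of the losses observed
    before each local step (a stand-in for the training metric).
    """
    losses = []
    for _ in range(steps):
        losses.append((w - target) ** 2)
        w -= lr * 2.0 * (w - target)  # gradient of (w - target)**2
    return w, sum(losses) / len(losses)


def federated_round(global_w, client_targets):
    """One round of federated averaging over all clients (toy version)."""
    results = [local_sgd(global_w, t) for t in client_targets]
    # Server step: plain average of the client weights.
    new_global = sum(w for w, _ in results) / len(results)
    # Training metric: average of what each client saw locally.
    train_loss = sum(loss for _, loss in results) / len(results)
    # Evaluation metric: loss of the averaged model on every client.
    eval_loss = sum((new_global - t) ** 2 for t in client_targets) / len(client_targets)
    return new_global, train_loss, eval_loss


# Two clients whose local optima disagree (a non-IID setup).
client_targets = [-1.0, 1.0]
global_w, train_loss, eval_loss = federated_round(0.0, client_targets)
print('train loss: {:.4f}, eval loss: {:.4f}'.format(train_loss, eval_loss))
```

Each client fits its own target almost perfectly, so the training loss is small, yet the averaged weight lands between the two optima and the evaluation loss stays high. This only shows that the gap is possible; it does not show that this is what happens in the question.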

Zachary Garrett
  • Thanks a lot for your reply. I feel the evaluation loss problem comes from a bug rather than from the number of rounds, since the evaluation accuracy and loss don't change at all (they are exactly the same) in rounds 1 to 40... I am very confused, since the training accuracy on the local clients looks good (> 90%), so there shouldn't be bugs in the custom model part, and I just used the evaluation method listed in the tutorial. Printing the `state` (server state) is a good idea. I printed it and checked all the trainable variables, but nothing looks wrong to me... Any hints on solving the bug? Thanks!! – miaoz18 Feb 18 '20 at 05:55
  • Hmm, I'm not sure with the current information. How about extending the question with the code that performs training, and the values seen printed each round? This could make diagnosing the problem easier. – Zachary Garrett Feb 18 '20 at 16:42
  • The code and the results printed each round have been added to the question! – miaoz18 Feb 19 '20 at 19:07
  • Note that `emnist_train` and `emnist_test` here are Retina data I loaded from h5 files, not EMNIST data (I should have given them proper names...) – miaoz18 Feb 19 '20 at 19:10

I found that the initialization is the reason why ResNet performs poorly. This is possibly because TFF uses a relatively simple state initialization that doesn't account for some layers such as batch norm, so when I assigned the normal Keras model's initial weights to the server instead of using its default initialization, the federated results were much better.
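
The answer does not say how the weights were assigned, so the following is only one possible way, not necessarily what was done here. It assumes a TFF release from the same period as the question, where `tff.learning.state_with_new_model_weights` is available (later releases removed it, so check it against the installed version). The helper names `keras_weights_as_numpy` and `seed_server_state` are made up for this sketch.

```python
def keras_weights_as_numpy(keras_model):
    """Split a Keras model's weights into (trainable, non_trainable) lists.

    Each variable is converted with .numpy(), keeping the model's own order.
    """
    trainable = [v.numpy() for v in keras_model.trainable_weights]
    non_trainable = [v.numpy() for v in keras_model.non_trainable_weights]
    return trainable, non_trainable


def seed_server_state(state, keras_model):
    """Return a copy of `state` whose model weights come from `keras_model`.

    Assumption: the installed TFF still provides
    tff.learning.state_with_new_model_weights.
    """
    import tensorflow_federated as tff  # imported lazily, only needed here

    trainable, non_trainable = keras_weights_as_numpy(keras_model)
    return tff.learning.state_with_new_model_weights(
        state,
        trainable_weights=trainable,
        non_trainable_weights=non_trainable)
```

With the names from the question, this would be called once before the training loop, for example `state = seed_server_state(iterative_process.initialize(), create_compiled_keras_model())`, so that the first round starts from the Keras model's weights, including the BatchNorm moving statistics.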

miaoz18
  • I'm having the same problem you were. For your solution, what do you mean by this comment: "I assigned the normal Keras model initial weights to the server instead of using its default initialization" How did you do this? – Kane Sep 15 '20 at 11:45
  • @miaoz18 Can you please tell us how you solved the problem? – seni Nov 16 '20 at 09:22
  • I have the same problem. How to assign the normal Keras model initial weights to the server? – tfreak Feb 15 '21 at 02:05