
I have designed a federated learning model with the TensorFlow Federated (TFF) framework. The iterative process is defined as below:

```python
import tensorflow as tf
import tensorflow_federated as tff

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.9))
```

I have two remote workers running the TFF remote executor service, and the context for running computations is set with `tff.backends.native.set_remote_python_execution_context(channels)`. When the model is broadcast to the clients with `iterative_process.next(state, train_data)`, how can I tell that the client metrics are aggregated and applied to the server model? Is the single `build_federated_averaging_process` API enough to get the metrics from the clients, aggregate them, and then update the server model? If so, how can I verify that the server model has been updated? Can anyone help me understand this?
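To make the round structure concrete, here is a toy pure-Python sketch (not TFF code; the model, data, and helper names are all illustrative) of the four steps a single `next` call performs: broadcast the server weights, run local SGD on each client, weighted-average the client deltas, and apply the averaged delta with the server optimizer.

```python
SERVER_LR = 0.9  # matches the server_optimizer_fn in the question

def client_update(server_weights, data, client_lr=1.0):
    """Run local SGD on one client; return (weight_delta, num_examples).

    Toy model: scalar weight w fit to pairs (x, y) with loss (w*x - y)^2.
    """
    w = server_weights
    for x, y in data:
        grad = 2 * (w * x - y) * x
        w = w - client_lr * grad
    return w - server_weights, len(data)

def server_update(server_weights, client_results, server_lr=SERVER_LR):
    """Weighted-average the client deltas and apply a server SGD step."""
    total = sum(n for _, n in client_results)
    avg_delta = sum(d * n for d, n in client_results) / total
    # Server SGD treats the negative average delta as a pseudo-gradient,
    # so the step adds server_lr * avg_delta to the server weights.
    return server_weights + server_lr * avg_delta

# One round with two clients, mirroring the two-remote-worker setup.
w0 = 0.0
results = [client_update(w0, [(1.0, 2.0)]),
           client_update(w0, [(1.0, 4.0)])]
w1 = server_update(w0, results)  # server model after one round
```

The point of the sketch is that aggregation and the server update both happen inside the single round function, which is what `build_federated_averaging_process` wires together for you.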

Eden

1 Answer


The `build_federated_averaging_process` API builds an iterative process that performs the full federated averaging steps: broadcasting the model, local client training, aggregating the client updates and metrics, and applying the result to the server model. If you want to verify that the server model is updated, you can print `state.model` after each call to `iterative_process.next(state, train_data)`.
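The snapshot-and-compare idea behind this check can be sketched with a toy stand-in (plain Python, not the TFF API; `next_round` and its toy scalar model are illustrative only): hold on to the state before the round, run the round, and compare.

```python
def next_round(state, train_data, client_lr=1.0, server_lr=0.9):
    """Toy stand-in for iterative_process.next on a scalar model.

    Returns (new_state, metrics), like the real TFF call.
    """
    deltas = []
    for x, y in train_data:
        grad = 2 * (state * x - y) * x   # gradient of (state*x - y)^2
        deltas.append(-client_lr * grad)  # one local SGD step per client
    avg_delta = sum(deltas) / len(deltas)
    new_state = state + server_lr * avg_delta
    loss = sum((state * x - y) ** 2 for x, y in train_data) / len(train_data)
    return new_state, {"loss": loss}

state = 0.0
for round_num in range(3):
    previous = state
    state, metrics = next_round(state, [(1.0, 2.0), (1.0, 4.0)])
    # The server model changed this round iff state != previous;
    # with real TFF you would compare state.model snapshots instead.
    changed = state != previous
```

With the real API the comparison is the same idea: keep the previous `state.model` weights and check that the new `state.model` differs after `next`.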

Wennan Zhu
  • `NUM_ROUNDS = 20; for round_num in range(NUM_ROUNDS): state, metrics = iterative_process.next(state, train_data); print(server.model)` I am getting an error: 'server' is not defined. Is this the right way to check the server model? Or can we check only the weights assigned to the model, as `state.model`? – crazynovatech Apr 12 '22 at 06:17
  • Yeah, it should be `state.model`, sorry about the typo. – Wennan Zhu Apr 12 '22 at 15:04