I have designed a federated learning model with the TensorFlow Federated (TFF) framework and defined the iterative process as below:
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.9))
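The initial server state that I later pass to next() is created in the usual way (shown here only for completeness):

state = iterative_process.initialize()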
I have two remote workers running the TFF remote executor service, and the execution context for running the computation is set with tff.backends.native.set_remote_python_execution_context(channels).
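For reference, the channels are created roughly as follows (the worker addresses below are placeholders for my actual hosts and ports):

import grpc
import tensorflow_federated as tff

# Placeholder addresses of the two remote executor services.
worker_addresses = ['10.0.0.1:8000', '10.0.0.2:8000']

# One gRPC channel per remote worker.
channels = [grpc.insecure_channel(addr) for addr in worker_addresses]

# Route TFF computations (broadcast, client training, aggregation) to the workers.
tff.backends.native.set_remote_python_execution_context(channels)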
When the model is broadcast to the clients with iterative_process.next(state, train_data), how can I verify that the client metrics are aggregated and that the aggregated update is applied to the server model? Is the single API build_federated_averaging_process enough to collect the metrics from the clients, aggregate them, and then update the server model? If so, how can I tell that the server model has actually been updated? Can anyone please help me understand this?
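Here is a minimal sketch of how I would try to check this, assuming the initial state comes from iterative_process.initialize() and that the server state exposes the model weights as state.model.trainable:

import numpy as np

# Snapshot the server's trainable weights before the round.
weights_before = [np.array(w) for w in state.model.trainable]

# One round: broadcast, local client training, aggregation, server update.
state, metrics = iterative_process.next(state, train_data)

# `metrics` should hold the aggregated client metrics (e.g. loss, accuracy) for this round.
print('aggregated client metrics:', metrics)

# If the server model was updated, the trainable weights should have changed.
weights_after = [np.array(w) for w in state.model.trainable]
updated = any(not np.allclose(b, a) for b, a in zip(weights_before, weights_after))
print('server model updated this round:', updated)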