
I am using TensorFlow Federated to train a text classification model with the federated learning approach. Is there any way to apply early stopping on the client side? Is there an option for cross-validation in the API? The only thing I was able to find is the evaluation:

evaluation = tff.learning.build_federated_evaluation(model_fn)

This is applied to the model at the end of a federated training round.

Am I missing something?

p20

1 Answer


One straightforward way to control the number of steps a client takes when using tff.learning.build_federated_averaging_process is to set up each client's tf.data.Dataset with different parameters, for example by limiting the number of steps with tf.data.Dataset.take. The guide "tf.data: Build TensorFlow input pipelines" has many more details.
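A minimal sketch of the per-client step cap described above, using only tf.data. The helper `make_client_dataset`, the random features and labels, and the step budgets are all hypothetical stand-ins for real per-client text data; the point is that `.take(max_steps)` after `.batch(...)` caps how many local steps each client can run.

```python
import tensorflow as tf


def make_client_dataset(num_examples, batch_size, max_steps):
    """Hypothetical helper: builds one client's training data and caps the
    number of local steps (batches) at `max_steps` via Dataset.take."""
    # Placeholder data; in practice this would be the client's own examples.
    features = tf.random.normal((num_examples, 4))
    labels = tf.random.uniform((num_examples,), maxval=3, dtype=tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    # take() is applied after batch(), so it limits batches (= local steps).
    return ds.shuffle(num_examples).batch(batch_size).take(max_steps)


# Different clients can be given different step budgets.
client_step_budgets = [2, 5, 10]
federated_train_data = [
    make_client_dataset(num_examples=100, batch_size=10, max_steps=steps)
    for steps in client_step_budgets
]
```

With 100 examples and a batch size of 10, each client has 10 batches available, so the budgets above yield 2, 5 and 10 local steps. A list like `federated_train_data` is what gets passed as the client data argument to the iterative process's `next(state, ...)` call.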

Alternatively, stopping based on a measurement of learning progress would currently require modifying some internals of the algorithm. Rather than using the APIs in tff.learning, it may be simpler to poke around federated/tensorflow_federated/python/examples/simple_fedavg/, in particular its client training loop, which could be modified to stop based on some criterion other than "end of dataset" (as currently used).
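A minimal sketch of what such a modified client loop might look like, written as plain TensorFlow eager code rather than copied from simple_fedavg. The function name `client_update_with_early_stopping`, the per-client held-out split `val_ds`, and the `max_epochs`/`patience` parameters are all hypothetical. Local training stops once the loss on the client's own validation data has not improved for `patience` consecutive local epochs.

```python
import tensorflow as tf


def client_update_with_early_stopping(model, train_ds, val_ds, optimizer,
                                      loss_fn, max_epochs=10, patience=2):
    """Hypothetical client update (not part of the TFF API): runs up to
    `max_epochs` local epochs and stops early once the loss on the client's
    held-out `val_ds` has not improved for `patience` consecutive epochs."""

    def mean_val_loss():
        # Average of per-batch losses over the client's validation split.
        total, count = 0.0, 0
        for x, y in val_ds:
            total += float(loss_fn(y, model(x, training=False)))
            count += 1
        return total / max(count, 1)

    best = float("inf")
    best_weights = model.get_weights()
    epochs_without_improvement = 0
    epochs_run = 0
    for _ in range(max_epochs):
        # One local epoch of plain gradient descent on the client's data.
        for x, y in train_ds:
            with tf.GradientTape() as tape:
                loss = loss_fn(y, model(x, training=True))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
        epochs_run += 1

        current = mean_val_loss()
        if current < best:
            best = current
            best_weights = model.get_weights()
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stop: skip the remaining local epochs

    # Report the weights from the best epoch, not necessarily the last one.
    model.set_weights(best_weights)
    return model.get_weights(), epochs_run
```

Returning `epochs_run` alongside the weights lets the server log or weight clients that stopped early. Wiring a loop like this into a federated computation (wrapping it in `tf.function` and TFF type signatures, as the simple_fedavg example does for its own loop) is not shown here.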

Zachary Garrett