
In federated learning, I want to get the weights of each local model every round so that I can cluster clients based on their weights. However, I can only use training_process.get_model_weights(train_state), which returns the global weights.

I have used training_process.get_model_weights(train_state) to get the global weights, but I haven't found any library or function that returns each client's weights.

baodvo

1 Answer


This is definitely possible. To do so, you would just need to write a tff.federated_computation that returns the CLIENTS-placed model weights.

For brevity, I'll illustrate this in a much simpler setting, but the same principle applies to model training. For example, let's say that for each client, I'm going to take some integer broadcast from the server, add it to the client's locally held integer, and return the result. I could do:

@tff.tf_computation(tf.int32, tf.int32)
def add(x, y):
  return x + y

server_int_type = tff.types.at_server(tf.int32)
client_int_type = tff.types.at_clients(tf.int32)

@tff.federated_computation(server_int_type, client_int_type)
def add_across_clients(x, y):
  x_at_clients = tff.federated_broadcast(x)
  return tff.federated_map(add, (x_at_clients, y))

Then add_across_clients(3, [1, 2, 5]) will return [4, 5, 8]. In other words, it returns a tff.CLIENTS-placed value, representing the collection as a Python list with one element per client.

You can do the same kind of thing with model training code: broadcast the global weights, apply local training (via tff.federated_map), and return the CLIENTS-placed result instead of aggregating it.
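As a rough sketch of that pattern, here is the same shape in plain NumPy, with Python lists standing in for the TFF runtime and CLIENTS placement. The names client_update and broadcast_and_train, and the one-step least-squares "training", are illustrative only and not part of any TFF API:

```python
import numpy as np

def client_update(global_w, x, y, lr=0.1):
    # One gradient step of least squares on a client's local data.
    # This plays the role of the tf_computation mapped over clients.
    grad = 2 * x.T @ (x @ global_w - y) / len(y)
    return global_w - lr * grad

def broadcast_and_train(global_w, client_datasets):
    # "Broadcast" the global weights, train locally on each client,
    # and return the CLIENTS-placed weights: one vector per client,
    # which the server can then cluster.
    return [client_update(global_w, x, y) for (x, y) in client_datasets]

global_w = np.zeros(2)
clients = [
    (np.array([[1.0, 0.0], [0.0, 1.0]]), np.array([1.0, 2.0])),
    (np.array([[1.0, 1.0], [1.0, -1.0]]), np.array([3.0, 1.0])),
]
per_client_weights = broadcast_and_train(global_w, clients)
# per_client_weights holds one updated weight vector per client.
```

In real TFF code, broadcast_and_train would be a tff.federated_computation using tff.federated_broadcast and tff.federated_map, and its return value would be the CLIENTS-placed weights rather than the result of an aggregation like tff.federated_mean.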