This is definitely possible. To do so, you just need to write a `tff.federated_computation` that returns the `CLIENTS`-placed model weights.
For brevity, I'll illustrate this in a much simpler setting, but the same principle applies to model training. Suppose that each client takes an integer broadcast from the server, adds it to its own locally held integer, and returns the result. I could do:
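```python
import tensorflow as tf
import tensorflow_federated as tff


# Runs on each client: adds two integers.
@tff.tf_computation(tf.int32, tf.int32)
def add(x, y):
    return x + y


server_int_type = tff.types.at_server(tf.int32)
client_int_type = tff.types.at_clients(tf.int32)


@tff.federated_computation(server_int_type, client_int_type)
def add_across_clients(x, y):
    # Send the server's integer to every client.
    x_at_clients = tff.federated_broadcast(x)
    # Add it to each client's own integer; the result stays CLIENTS-placed.
    return tff.federated_map(add, (x_at_clients, y))
```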
Then `add_across_clients(3, [1, 2, 5])` will return `[4, 5, 8]`. In other words, it returns a `tff.CLIENTS`-placed value, represented as a Python list with one entry per client.
You can do the same kind of thing with model training code: broadcast some weights, apply local training (via `tff.federated_map`), and return the result.