I designed a function that calculates gradients from a loss and model.trainable_variables with TensorFlow's GradientTape. I use it for split learning, which means the model is divided and trained on a client up to a specific layer. The output, the labels, and the trainable variables of the client-side model are then sent to a server, which completes the training on the second half of the model. On the server, this function should calculate both a server-side gradient and a client-side gradient; the latter is sent back to the client to update the client-side model:
def calc_gradients(self, msg):
    with tf.GradientTape(persistent=True) as tape:
        output_client, labels, trainable_variables_client = msg["client_out"], msg["label"], msg["trainable_variables"]
        output_server = self.call(output_client)
        l = self.loss(output_server, labels)
    gradient_server = tape.gradient(l, self.model.trainable_variables)
    print(l)
    print(trainable_variables_client)
    gradient_client = tape.gradient(l, trainable_variables_client)
    print("client gradient:")
    print(gradient_client)
The server-side gradient is calculated correctly. The loss is also calculated correctly, and the client's trainable_variables are received correctly on the server, but the client-side gradient

gradient_client = tape.gradient(l, trainable_variables_client)

returns only:

client gradient:
[None, None, None, None]
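
As far as I understand, tape.gradient returns None for a source that has no recorded path to the target on that tape. A minimal standalone sketch of that behavior (the variable names are made up):

import tensorflow as tf

v_used = tf.Variable(2.0)
v_unused = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    y = v_used * 5.0  # only v_used takes part in the recorded computation

print(tape.gradient(y, v_used))    # tf.Tensor(5.0, shape=(), dtype=float32)
print(tape.gradient(y, v_unused))  # None: no recorded path from v_unused to y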
The msg is a dictionary with the data that is sent from the client to the server:
def start_train(self, batchround):
    if batchround in range(self.num_samples // self.batch_size):
        with tf.GradientTape(persistent=True) as tape:
            output_client, labels = self.send(batchround)
            client_trainable_variables = self.model.trainable_variables
            msg = {
                'client_out': output_client,
                'label': labels,
                'trainable_variables': client_trainable_variables,
                'batchround': batchround
            }
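
For completeness, here is a minimal, self-contained sketch that shows the same symptom outside my client/server code (the two Sequential halves, the shapes, and all names are made up; as in my setup, the client forward pass runs outside the server's tape):

import tensorflow as tf

# Stand-ins for the client-side and server-side halves of the split model.
client_model = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu")])
server_model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal((16, 4))
labels = tf.random.normal((16, 1))

# "Client side": forward pass up to the cut layer.
output_client = client_model(x)
trainable_variables_client = client_model.trainable_variables

# "Server side": finish the forward pass and compute the loss.
with tf.GradientTape(persistent=True) as tape:
    output_server = server_model(output_client)
    l = loss_fn(labels, output_server)

print(tape.gradient(l, server_model.trainable_variables))  # valid gradients
print(tape.gradient(l, trainable_variables_client))        # [None, None]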