
I designed a function to calculate gradients from the loss and model.trainable_variables with TensorFlow's GradientTape. I use this function to perform split learning, which means the model is divided and trained on a client up to a specific layer. The output, the labels, and the trainable_variables of the client-side model are then sent to a server, which completes the training on the second half of the model. On the server, a server-side gradient and a client-side gradient (the latter to be sent back to the client to update the client-side model) should be calculated with this function:

def calc_gradients(self, msg):
    with tf.GradientTape(persistent=True) as tape:
        output_client, labels, trainable_variables_client = msg["client_out"], msg["label"], msg["trainable_variables"]
        output_server = self.call(output_client)

        l = self.loss(output_server, labels)
        gradient_server = tape.gradient(l, self.model.trainable_variables)  # computed correctly

        print(l)
        print(trainable_variables_client)

        gradient_client = tape.gradient(l, trainable_variables_client)  # returns only None entries

        print("client gradient:")
        print(gradient_client)

The server-side gradient is calculated correctly, the loss is calculated correctly, and the trainable_variables of the client are received correctly, but the client-side gradient, gradient_client = tape.gradient(l, trainable_variables_client), returns only:

client gradient:
[None, None, None, None]
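
For context, tape.gradient returns None for every source that the tape never saw participating in the computation of the target. A minimal standalone reproduction of that behavior (the names v and w are illustrative, not taken from the code above):

import tensorflow as tf

v = tf.Variable(1.0)  # never used while the tape is recording
w = tf.Variable(2.0)
with tf.GradientTape() as tape:
    y = w * 3.0       # only w contributes to y
print(tape.gradient(y, [v, w]))  # [None, <tf.Tensor: ... numpy=3.0>]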

The msg is a dictionary with the data that is sent from the client to the server:

def start_train(self, batchround):
    if batchround in range(self.num_samples // self.batch_size): 
        with tf.GradientTape(persistent=True) as tape:
            output_client, labels = self.send(batchround) 

            client_trainable_variables = self.model.trainable_variables
            msg = {
                'client_out': output_client,
                'label': labels,
                'trainable_variables': client_trainable_variables,
                'batchround': batchround
            }
  • The server model is used in computing the loss, but the client model doesn't contribute to it. Therefore, no gradients exist for the client model with respect to the loss. – Susmit Agrawal Jun 07 '20 at 16:01
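To illustrate the point in the comment: the client's forward pass was never recorded by the server's tape, so the tape has no path from l back to trainable_variables_client, and tape.gradient returns None for each of those variables. A common split-learning pattern is instead to compute the gradient of the loss with respect to the received activations and send that tensor back, letting the client finish the chain rule with its own tape. The following is a minimal sketch, not the original code; it assumes the client keeps the tape from start_train alive until the server's reply arrives, and names such as server_tape, client_tape, and grad_client_out are illustrative:

# Server side (inside calc_gradients): watch the received activations so
# the tape records a path from the loss back to them.
with tf.GradientTape(persistent=True) as server_tape:
    server_tape.watch(output_client)
    output_server = self.call(output_client)
    l = self.loss(output_server, labels)

gradient_server = server_tape.gradient(l, self.model.trainable_variables)
# Gradient of the loss w.r.t. the cut-layer activations; this tensor, rather
# than gradients for the client's variables, is what the server can provide.
grad_client_out = server_tape.gradient(l, output_client)

# Client side (after receiving grad_client_out): finish the chain rule with
# the tape that recorded the client forward pass, using output_gradients to
# seed dl/d(output_client) into d(output_client)/d(client variables).
gradient_client = client_tape.gradient(
    output_client,
    self.model.trainable_variables,
    output_gradients=grad_client_out,
)

With this split, neither side needs the other's trainable_variables, which would also remove the cost of shipping the client's weights in msg.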
