import tensorflow as tf

@tf.function
def update(model, dataset, weights, optimizer):
  trainable_weights = model.trainable_variables

  # Copy the incoming (server) weights into the local model variables.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        trainable_weights, weights)

  for batch in dataset:
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)

    grads = tape.gradient(outputs.loss, trainable_weights)
    norm = tf.linalg.global_norm(grads)

    optimizer.apply_gradients(zip(grads, trainable_weights))

  # grads and norm are only assigned inside the loop, which triggers the error below.
  return trainable_weights, grads, norm

This returns None, and raises an error for grads and norm saying they must be declared before the loop. I want to compute the gradient norm for each client and then compare the norms across clients.
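One possible way around the error, as a minimal sketch rather than a confirmed fix: inside a tf.function, any value that is read after a dataset loop has to exist before the loop, so grads and norm can be given placeholder values first. The zero initializers and the running mean of per-batch norms below are illustrative additions of mine, not part of the original code; the forward_pass call just follows the model interface used in the snippet above.

import tensorflow as tf

@tf.function
def update(model, dataset, weights, optimizer):
  trainable_weights = model.trainable_variables
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        trainable_weights, weights)

  # AutoGraph requires anything read after the loop to be defined before it,
  # so initialize grads and norm with placeholder values here.
  grads = [tf.zeros_like(w) for w in trainable_weights]
  norm = tf.constant(0.0)
  norm_sum = tf.constant(0.0)
  num_batches = tf.constant(0.0)

  for batch in dataset:
    with tf.GradientTape() as tape:
      outputs = model.forward_pass(batch)
    grads = tape.gradient(outputs.loss, trainable_weights)
    norm = tf.linalg.global_norm(grads)  # gradient norm for this batch
    norm_sum += norm
    num_batches += 1.0
    optimizer.apply_gradients(zip(grads, trainable_weights))

  # Mean per-batch gradient norm for this client.
  mean_norm = norm_sum / tf.maximum(num_batches, 1.0)
  return trainable_weights, grads, mean_norm

Calling update once per client dataset then yields one norm value per client, and those values can be collected (for example stacked into a tensor) and compared outside the function.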
