@tf.function
def update(model, dataset, weights, optimizer):
trainable_weights = model.trainable_variables
tf.nest.map_structure(lambda x, y: x.assign(y),
trainable_weights, weights)
for batch in dataset:
with tf.GradientTape() as tape:
outputs = model.forward_pass(batch)
grads = tape.gradient(outputs.loss, trainable_weights)
norm = tf.linalg.global_norm(grads)
grads_and_vars = zip(grads, client_weights)
optimizer.apply_gradients(grads_and_vars)
return trainable_weights, grads, norm
its returns none and error for grads, norm. Error it must be decalred before loop. I want to compute norm of each client and compare the norm of them.