I am building a multi-class classification federated learning model using TensorFlow (TensorFlow Federated). I want to generate a confusion matrix for my model, but I don't know how to obtain y_true and y_pred in my federated computation code. The federated computation code:

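import pandas as pd

# iterative_process, eval_process, train_data and test_data are defined elsewhere (not shown).
def train(NUM_ROUNDS, data_frame):
  state = iterative_process.initialize()
  for round_num in range(NUM_ROUNDS):
    # Metrics of the current global model on test_data, taken before this round's update.
    train_metrics = eval_process(state.model, test_data)['eval']
    state, _ = iterative_process.next(state, train_data)
    print(f'Round {round_num:3d}: {train_metrics}')
    # DataFrame.append was removed in pandas 2.0; pd.concat does the same job.
    row = pd.DataFrame([{'Round': round_num, **train_metrics}])
    data_frame = pd.concat([data_frame, row], ignore_index=True)

  test_metrics = eval_process(state.model, test_data)
  print("The final evaluation is: ")
  print(test_metrics)

  return data_frame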

data_frame = pd.DataFrame()
NUM_ROUNDS = 2

print('Starting training')
data_frame = train(NUM_ROUNDS, data_frame)
print()

Output:

Starting training
Round   0: OrderedDict([('sparse_categorical_accuracy', 0.12227074), ('loss', 1.3862933), ('num_examples', 916), ('num_batches', 184)])
Round   1: OrderedDict([('sparse_categorical_accuracy', 0.57969433), ('loss', 1.7442805), ('num_examples', 916), ('num_batches', 184)])
The final evaluation is: 
OrderedDict([('eval', OrderedDict([('sparse_categorical_accuracy', 0.17467248), ('loss', 1.7451892), ('num_examples', 916), ('num_batches', 184)]))])
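The final evaluation result above is a nested OrderedDict: `eval_process` returns its metrics under an `'eval'` key. As a minimal pure-Python sketch, a small helper can flatten such a result into a single row before it goes into `data_frame`; `flatten_metrics` is a hypothetical name, not part of TensorFlow Federated.

```python
from collections import OrderedDict


def flatten_metrics(metrics, prefix=""):
    """Flatten a nested (Ordered)Dict of metrics into one flat OrderedDict.

    Nested keys are joined with '/', e.g. {'eval': {'loss': 1.7}} -> {'eval/loss': 1.7}.
    """
    flat = OrderedDict()
    for key, value in metrics.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            # OrderedDict is a dict subclass, so nested results are recursed into.
            flat.update(flatten_metrics(value, name))
        else:
            flat[name] = value
    return flat


# Example shaped like the final evaluation output printed above.
final = OrderedDict([('eval', OrderedDict([
    ('sparse_categorical_accuracy', 0.17467248), ('loss', 1.7451892),
    ('num_examples', 916), ('num_batches', 184)]))])
print(flatten_metrics(final))
```

Inside the training loop the row would then be `{'Round': round_num, **flatten_metrics(metrics)}`. Collecting the rows in a list and calling `pd.DataFrame(rows)` once after the loop also avoids `DataFrame.append`, which was removed in pandas 2.0.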

The confusion matrix code is:

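import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

classes = [0, 1, 2, 3]
logdir = 'log'
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

# y_true: integer labels of the test examples; y_pred: predicted class ids
# (e.g. argmax of the model outputs). Obtaining these is the open question.
# num_classes keeps the matrix 4x4 even if a class never appears.
con_mat = tf.math.confusion_matrix(labels=y_true, predictions=y_pred,
                                   num_classes=len(classes)).numpy()
# Row-normalize so each row (one true class) sums to 1.
con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)

con_mat_df = pd.DataFrame(con_mat_norm,
                          index=classes,
                          columns=classes)

figure = plt.figure(figsize=(8, 8))
sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
plt.show()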
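To see exactly what the plotting code expects to receive, here is a dependency-free sketch of what `tf.math.confusion_matrix` computes (rows are true labels, columns are predicted labels) plus the row normalization. The helper names and the toy labels are made up for illustration. One deliberate difference from the NumPy version: a class with no examples gets a row of zeros instead of the NaN produced by a 0/0 division.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Count matrix with rows = true labels, columns = predicted labels
    (the same orientation as tf.math.confusion_matrix(labels, predictions))."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


def normalize_rows(matrix, decimals=2):
    """Divide each row by its total; empty rows stay all zero."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([round(c / total, decimals) if total else 0.0 for c in row])
    return normalized


def accuracy(matrix):
    """Fraction of examples on the diagonal, i.e. correct predictions."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


# Toy labels for the four classes 0..3.
y_true = [0, 1, 2, 3, 3, 2]
y_pred = [0, 2, 2, 3, 1, 2]
cm = confusion_matrix(y_true, y_pred, num_classes=4)
print(cm)                  # [[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0], [0, 1, 0, 1]]
print(normalize_rows(cm))  # last row: [0.0, 0.5, 0.0, 0.5]
print(accuracy(cm))
```

A useful sanity check: the diagonal fraction computed on the same test data with the same global weights should agree with `sparse_categorical_accuracy`. For example, the final value 0.17467248 on 916 examples corresponds to 160 examples on the diagonal.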

So, is this the right way to generate a confusion matrix for a federated learning model, and how can I obtain y_true to pass to the function?
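One possible approach, assuming the evaluation data can be iterated centrally (as `test_data`, a list of client datasets, can be in simulation): copy the final global weights into an ordinary Keras model, run it over every client's test batches, and keep the labels together with the argmax of the outputs. In TFF versions where `state.model` is a `tff.learning.ModelWeights`, its `assign_weights_to(keras_model)` method performs the copy (newer releases expose the weights through `iterative_process.get_model_weights(state)`); check this against the installed version. The sketch keeps the loop framework-free so it runs anywhere: `collect_labels_and_predictions`, `toy_predict` and the toy clients are hypothetical, and `predict_fn` stands in for something like `lambda x: keras_model.predict(x, verbose=0)`.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

# One batch = (features, labels); labels are integer class ids.
Batch = Tuple[Sequence, Sequence[int]]


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def collect_labels_and_predictions(
    client_datasets: Iterable[Iterable[Batch]],
    predict_fn: Callable[[Sequence], Sequence[Sequence[float]]],
) -> Tuple[List[int], List[int]]:
    """Run predict_fn over every batch of every client and return flat
    y_true / y_pred lists. predict_fn must return one row of per-class
    scores (logits or probabilities) per example."""
    y_true: List[int] = []
    y_pred: List[int] = []
    for dataset in client_datasets:
        for features, labels in dataset:
            scores = predict_fn(features)
            y_true.extend(int(label) for label in labels)
            y_pred.extend(argmax(row) for row in scores)
    return y_true, y_pred


# Hypothetical stand-in model: it always scores class (x % 4) highest.
def toy_predict(features):
    return [[1.0 if c == x % 4 else 0.0 for c in range(4)] for x in features]


clients = [
    [([0, 1], [0, 1]), ([2], [3])],  # client A: two batches
    [([3, 5], [3, 1])],              # client B: one batch
]
y_true, y_pred = collect_labels_and_predictions(clients, toy_predict)
print(y_true, y_pred)
```

With `tf.data` datasets, iterating `dataset.as_numpy_iterator()` yields NumPy batches. If the batches are dicts such as `OrderedDict(x=..., y=...)` rather than tuples, unpack `batch['x']` and `batch['y']` instead, and flatten labels shaped `(batch, 1)`. The two resulting lists can then be passed as `labels=y_true, predictions=y_pred` to `tf.math.confusion_matrix`.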

Eden