I have followed this EMNIST tutorial to create an image classification experiment (7 classes), with the aim of training a classifier on 3 silos of data with the TFF framework.
Before training begins, I convert the model to a TF Keras model using tff.learning.assign_weights_to_keras_model(model, state.model)
in order to evaluate on my validation set. Regardless of the label, the model predicts only one class. This is to be expected, as no training of the model has occurred yet. However, I repeat this step after each federated averaging round and the problem persists: all validation images are predicted as one class. I also save the TF Keras model weights after each round and make predictions on the test set - no change.
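For concreteness, this is roughly the per-round check I run (a minimal sketch; it assumes tfkeras_model is the converted Keras model and val_data is a batched tf.data.Dataset of (image, one-hot label) pairs):

import numpy as np

# Count how often each class is predicted on the validation set.
preds = np.argmax(tfkeras_model.predict(val_data), axis=-1)
classes, counts = np.unique(preds, return_counts=True)
print(dict(zip(classes, counts)))  # e.g. {0: 450} when everything collapses to class 0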
Some of the steps I have taken to check the source of the issue:
- Checked whether the TF Keras model weights update when the FL model is converted after each round - they do.
- Ensured that the shuffle buffer size is greater than the training dataset size for each client.
- Compared the predictions to the class distribution in the training datasets (see the sketch after this list). There is a class imbalance, but the single class the model predicts is not necessarily the majority class, and it is not always the same class; for the most part, it predicts class 0.
- Increased the number of rounds to 5 and epochs per round to 10. This is computationally very intensive as it is quite a large model being trained with approx 1500 images per client.
- Investigated the TensorBoard logs from each training attempt. The training loss decreases as the rounds progress.
- Tried a much simpler model - a basic CNN with 2 conv layers. This allowed me to greatly increase the number of epochs and rounds. When evaluating this model on the test set, it predicted 4 different classes, but performance remained very poor. This suggests that I would just need to increase the number of rounds and epochs for my original model to get more variation in its predictions, which is difficult given the long training time that would result.
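The label-distribution check mentioned above looks roughly like this (a sketch; it assumes client_datasets is the list of batched per-silo tf.data.Datasets with one-hot labels that gets passed to federated_averaging.next):

import numpy as np

# Per-silo label counts, to compare against the single predicted class.
for i, ds in enumerate(client_datasets):
    labels = np.concatenate([np.argmax(y, axis=-1) for _, y in ds])
    classes, counts = np.unique(labels, return_counts=True)
    print(f"client {i}: {dict(zip(classes, counts))}")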
Model details:
The model uses XceptionNet as the base model with its weights unfrozen. It performs well on the classification task when all the training images are pooled into a global dataset, and our aim is to achieve comparable performance with FL.
base_model = Xception(include_top=False,
                      weights=weights,
                      input_shape=input_shape)
# Pool the 4-D feature map from the backbone before the classifier head
x = GlobalAveragePooling2D()(base_model.output)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
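(On the snippet above: with include_top=False and no pooling argument, base_model.output is a 4-D feature map, which is why the explicit GlobalAveragePooling2D is needed before the Dense layer; passing pooling='max' to Xception as well would already collapse the spatial dimensions and make the second pooling fail.)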
Here is my training code:
import os
import pathlib

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tqdm import tqdm

# Note: these are methods of my training class; the surrounding class definition is omitted.

def fit(self):
    """Train FL model"""
    # self.load_data()
    summary_writer = tf.summary.create_file_writer(self.logs_dir)
    federated_averaging = self._construct_iterative_process()
    state = federated_averaging.initialize()
    tfkeras_model = self._convert_to_tfkeras_model(state)
    # Sanity check before any training: predictions of the untrained model
    print(np.argmax(tfkeras_model.predict(self.val_data), axis=-1))
    val_loss, val_acc = tfkeras_model.evaluate(self.val_data, steps=100)

    with summary_writer.as_default():
        # range(1, num_rounds) ran one round too few; num_rounds + 1 fixes the off-by-one
        for round_num in tqdm(range(1, self.num_rounds + 1), ascii=True, desc="FedAvg Rounds"):
            print("Beginning fed avg round...")
            # Round of federated averaging
            state, metrics = federated_averaging.next(state, self.training_data)
            print("Fed avg round complete")
            # Saving logs
            for name, value in metrics._asdict().items():
                tf.summary.scalar(name, value, step=round_num)
            print("round {:2d}, metrics={}".format(round_num, metrics))
            # Push the updated server weights back into the Keras model
            tff.learning.assign_weights_to_keras_model(tfkeras_model, state.model)
            # tfkeras_model = self._convert_to_tfkeras_model(state)
            val_metrics = {}
            val_metrics["val_loss"], val_metrics["val_acc"] = tfkeras_model.evaluate(
                self.val_data, steps=100)
            for name, metric in val_metrics.items():
                tf.summary.scalar(name=name, data=metric, step=round_num)
            self._checkpoint_tfkeras_model(tfkeras_model, round_num, self.checkpoint_dir)

def _checkpoint_tfkeras_model(self, model, round_number, checkpoint_dir):
    # Build the directory path for this round's checkpoint
    model_dir = os.path.join(checkpoint_dir, f'round_{round_number}')
    # Create the directory (exist_ok so re-running a round does not raise)
    pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
    model_path = os.path.join(model_dir, f'model_file_round{round_number}.h5')
    # Save the full model in HDF5 format
    model.save(model_path)

def _convert_to_tfkeras_model(self, state):
    """Converts the global TFF model to a TF Keras model.

    Takes the weights of the global model and pushes them
    back into a standard Keras model.

    Args:
        state: The state of the FL server, containing the
            model and optimization state.

    Returns:
        (model): TF Keras model
    """
    model = self._load_tf_keras_model()
    model.compile(loss=self.loss, metrics=self.metrics)
    tff.learning.assign_weights_to_keras_model(model, state.model)
    return model

def _load_tf_keras_model(self):
    """Loads a TF Keras model.

    Raises:
        KeyError: A model name was not defined correctly.

    Returns:
        (model): TF Keras model object
    """
    model = create_models(
        model_type=self.model_type,
        input_shape=[self.img_h, self.img_w, 3],
        freeze_base_weights=self.freeze_weights,
        num_classes=self.num_classes,
        compile_model=False
    )
    return model

def _define_model(self):
    """Model creation function"""
    model = self._load_tf_keras_model()
    tff_model = tff.learning.from_keras_model(
        model,
        dummy_batch=self.sample_batch,
        loss=self.loss,
        # Using self.metrics throws an error
        metrics=[tf.keras.metrics.CategoricalAccuracy()]
    )
    return tff_model

def _construct_iterative_process(self):
    """Constructing federated averaging process"""
    iterative_process = tff.learning.build_federated_averaging_process(
        self._define_model,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
    )
    return iterative_process
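In case it helps, here is a minimal, self-contained sketch of the same FedAvg configuration on synthetic silos. The toy CNN, dataset sizes, and round count are placeholders, and it assumes the same TFF release as my code above (where tff.learning.from_keras_model still takes a dummy_batch):

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

NUM_CLIENTS, NUM_CLASSES, IMG = 3, 7, 32

def make_client_dataset(seed):
    # Synthetic stand-in for one silo: random images with one-hot labels.
    rng = np.random.RandomState(seed)
    x = rng.rand(64, IMG, IMG, 3).astype(np.float32)
    y = tf.keras.utils.to_categorical(
        rng.randint(NUM_CLASSES, size=64), NUM_CLASSES)
    return tf.data.Dataset.from_tensor_slices((x, y)).shuffle(64).batch(8)

client_data = [make_client_dataset(i) for i in range(NUM_CLIENTS)]
sample_batch = tf.nest.map_structure(
    lambda t: t.numpy(), next(iter(client_data[0])))

def create_keras_model():
    # Toy CNN standing in for the Xception-based model.
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, 3, activation='relu',
                               input_shape=(IMG, IMG, 3)),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
    ])

def model_fn():
    return tff.learning.from_keras_model(
        create_keras_model(),
        dummy_batch=sample_batch,
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()])

process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

state = process.initialize()
for round_num in range(1, 6):
    state, metrics = process.next(state, client_data)
    print('round {:2d}, metrics={}'.format(round_num, metrics))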