I'll answer my own question, as it may help someone in the future.
The 'overriding' approach was not the correct one.
The label encoding and decoding steps are pre- and post-processing steps. As such, they should not be 'shoehorned' into the `fit()` and `predict()` methods, but should instead be added as additional layers in the `Sequential` model.
This keeps concerns separated and doesn't hide the pre- and post-processing steps, as they'll be visible when one inspects a loaded model via `tf.keras.Model.summary()`, for example.
I ended up following a two-step approach:
- Training: I create a label encoder object that takes care of encoding the original labels into a one-hot-encoded 2D array. I've used a `keras.layers.IntegerLookup` layer to accomplish this. I then pass the original labels through this label encoder, and simply `fit()` the model with the encoded labels.
- Inference: after training the model, I create an 'inference' version of it (perhaps 'pipeline' would be a better term than 'model') by adding two post-processing layers: (a) a custom `argmax` layer that extracts the encoded label with the highest probability; and (b) a label decoding layer (also based on a `keras.layers.IntegerLookup` object) that essentially does the opposite of the pre-processing layer used in step 1.
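Stripped of TensorFlow, the two steps above boil down to a lookup in one direction and an argmax plus a reverse lookup in the other. Here is a NumPy-only sketch of that round trip. It is an illustration, not the Keras implementation: the helper names `encode_labels` and `decode_probabilities` are hypothetical, and the vocabulary is written out in sorted order for readability (`IntegerLookup.adapt()` builds its own ordering).

```python
import numpy as np


def encode_labels(labels, vocabulary):
    """Hypothetical helper: one-hot encode `labels` against `vocabulary`
    with no OOV slot, mimicking IntegerLookup(output_mode="one_hot",
    num_oov_indices=0)."""
    # Map each original label to its position in the vocabulary.
    index_of = {label: i for i, label in enumerate(vocabulary)}
    indices = np.array([index_of[label] for label in labels])
    # Row i gets a 1 in the column of its label's index.
    return np.eye(len(vocabulary), dtype="float32")[indices]


def decode_probabilities(probabilities, vocabulary):
    """Hypothetical helper: pick the most likely index per row (the ArgMax
    step), then map it back to the original label (the inverted lookup)."""
    indices = np.argmax(probabilities, axis=1)
    return np.asarray(vocabulary)[indices]


# Labels 1, 3 and 5 are absent, as with discard=[1, 3, 5] in the example.
vocabulary = [0, 2, 4, 6, 7, 8, 9]
labels = np.array([7, 2, 0, 9, 4])

one_hot = encode_labels(labels, vocabulary)
print(one_hot.shape)                              # (5, 7)
print(decode_probabilities(one_hot, vocabulary))  # [7 2 0 9 4]
```

Because the missing digits never enter the vocabulary, the model only needs 7 output units, and the decoding step restores the original digit values.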
After step 2, I can save the 'inference' version of the model with `model.save()` (equivalently, `keras.models.save_model()`), and the saved model includes the post-processing layers. After loading it with `tf.keras.models.load_model()`, all I have to do is call `predict()`, which directly returns an array with the predicted class labels in their original format.
To implement the `argmax` layer, I had to write a custom Keras layer (`ArgMax`), as shown in the example at the bottom.
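The `axis` argument of that layer matters: the softmax output has shape `(batch_size, num_classes)`, so the reduction has to run over axis 1 to get one class index per sample. A small NumPy sketch of the difference (`np.argmax` follows the same axis convention as `tf.math.argmax`; the probabilities are made up):

```python
import numpy as np

# Made-up softmax output for a batch of 2 samples over 3 classes.
probabilities = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
])

per_sample = np.argmax(probabilities, axis=1)  # one index per sample (wanted)
per_class = np.argmax(probabilities, axis=0)   # one index per class (wrong here)

print(per_sample)  # [1 0]
print(per_class)   # [1 0 0]
```

For a 2D input, `axis=-1` is equivalent to `axis=1`, and it would still reduce over the class dimension if extra leading dimensions were ever added.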
For reference, here's a concrete example:
```python
import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.datasets import mnist
from keras import layers


class ArgMax(tf.keras.layers.Layer):
    """
    Custom Keras layer that extracts the labels from
    an array of probabilities per label.
    """

    def __init__(self, **kwargs):
        # forward kwargs (name, dtype, ...) so the layer can be
        # re-created from its config when the model is loaded
        super().__init__(**kwargs)

    def call(self, inputs):
        # inputs has shape (batch_size, num_classes): return the index
        # of the most probable class for each sample
        return tf.math.argmax(inputs, axis=1)


def load_dataset(discard=()):
    """
    Loads the MNIST dataset, filters out unwanted labels and reshapes arrays.
    """
    (X_tr, y_tr), (X_val, y_val) = mnist.load_data()

    # keep only the samples whose label is not in `discard`
    keep_tr = ~np.isin(y_tr, discard)
    keep_val = ~np.isin(y_val, discard)
    X_tr, y_tr = X_tr[keep_tr], y_tr[keep_tr]
    X_val, y_val = X_val[keep_val], y_val[keep_val]

    # flatten each image into a vector and scale pixel values to [0, 1]
    NUM_ROWS = X_tr.shape[1]
    NUM_COLS = X_tr.shape[2]
    X_tr = X_tr.reshape((X_tr.shape[0], NUM_ROWS * NUM_COLS))
    X_val = X_val.reshape((X_val.shape[0], NUM_ROWS * NUM_COLS))
    X_tr = X_tr.astype('float32') / 255
    X_val = X_val.astype('float32') / 255

    return (X_tr, y_tr), (X_val, y_val)


if __name__ == "__main__":
    # load dataset: discard some of the labels
    # to test correct operation of pre- and post-processing layers
    (X_tr, y_tr), (X_val, y_val) = load_dataset(discard=[1, 3, 5])

    # label pre-processing: original label -> one-hot vector
    label_preprocessing = layers.IntegerLookup(
        output_mode="one_hot",
        num_oov_indices=0
    )
    label_preprocessing.adapt(y_tr)
    print(f"vocabulary : {label_preprocessing.get_vocabulary()}")
    print(f"vocabulary size : {len(label_preprocessing.get_vocabulary())}")

    # label post-processing: index -> original label; reuse the encoder's
    # vocabulary so both layers are guaranteed to share the same ordering
    label_postprocessing = layers.IntegerLookup(
        vocabulary=label_preprocessing.get_vocabulary(),
        num_oov_indices=0,
        invert=True
    )
    print(f"vocabulary : {label_postprocessing.get_vocabulary()}")
    print(f"vocabulary size : {len(label_postprocessing.get_vocabulary())}")

    # create model using Sequential API
    model = Sequential()
    model.add(tf.keras.layers.Dense(512, activation='relu', input_shape=(X_tr.shape[1],)))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(256, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.25))
    model.add(tf.keras.layers.Dense(len(np.unique(y_tr)), activation='softmax'))
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

    # fit the model using the pre-processed labels
    model.fit(X_tr, label_preprocessing(y_tr),
              batch_size=128,
              epochs=10,
              verbose=1,
              validation_data=(X_val, label_preprocessing(y_val)))

    # create model for inference, i.e., with 2 post-processing layers:
    # - add a layer that does argmax() operation
    # - add a layer to invert the integer labels
    model.add(ArgMax())
    model.add(label_postprocessing)

    # save the model (TF 2.x / Keras 2: a path without an extension is saved
    # in the SavedModel format; Keras 3 requires a path ending in '.keras')
    model.save('inference_model')

    # load the model
    loaded_model = tf.keras.models.load_model('inference_model')

    # compare the first 20 predictions of the loaded model to the ground truth
    print(loaded_model.predict(X_val[:20]))
    print(y_val[:20])
```
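One detail worth keeping in mind: according to the Keras docs, `IntegerLookup.adapt()` orders the vocabulary by descending frequency rather than numerically. Adapting both layers on the same `y_tr` should yield the same order, but if the decoding layer were adapted on different data or given a hand-written (say, sorted) vocabulary, predictions would be decoded to the wrong labels without any error. Passing `vocabulary=label_preprocessing.get_vocabulary()` to the inverted layer is one way to rule that out. The NumPy sketch below emulates the frequency ordering with `collections.Counter` (the helper `frequency_vocabulary` is hypothetical) to show the failure mode:

```python
from collections import Counter

import numpy as np


def frequency_vocabulary(labels):
    """Hypothetical helper: order labels by descending frequency,
    roughly emulating how IntegerLookup.adapt() builds its vocabulary."""
    return [label for label, _ in Counter(labels.tolist()).most_common()]


labels = np.array([4, 4, 4, 0, 0, 9])
encoder_vocab = frequency_vocabulary(labels)  # [4, 0, 9]
sorted_vocab = sorted(encoder_vocab)          # [0, 4, 9]

# Indices the model would learn to predict for the labels above.
indices = np.array([encoder_vocab.index(label) for label in labels.tolist()])

# Decoding with the encoder's own vocabulary recovers the labels...
print(np.asarray(encoder_vocab)[indices])  # [4 4 4 0 0 9]
# ...while a differently ordered vocabulary silently scrambles them.
print(np.asarray(sorted_vocab)[indices])   # [0 0 0 4 4 9]
```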