I have a model that uses MobileNetV2 as a frozen feature extractor:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot

# Frozen MobileNetV2 backbone pretrained on ImageNet.
feature_extractor = tf.keras.applications.MobileNetV2(
    include_top=False,
    weights='imagenet',
    alpha=1.0,
    input_shape=(224, 224, 3)
)
feature_extractor.trainable = False

# Classification head: pooled features -> single-logit binary output.
x = tf.keras.layers.GlobalAveragePooling2D()(feature_extractor.output)
x = tf.keras.layers.Dropout(0.2)(x)
x = tf.keras.layers.Dense(32)(x)
x = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(feature_extractor.input, x)

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_dataset, epochs=4, validation_data=test_dataset)
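(train_dataset and test_dataset are not shown above; they are tf.data pipelines of 224×224 RGB images with binary labels. A minimal stand-in with synthetic data, just to make the shapes concrete — the real pipeline comes from tfds:)

```python
import tensorflow as tf

# Synthetic stand-in: any (image, label) pipeline with this element spec
# will work with the model above.
images = tf.random.uniform((16, 224, 224, 3))                # pixel data in [0, 1]
labels = tf.random.uniform((16,), maxval=2, dtype=tf.int32)  # binary labels

train_dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
                 .shuffle(16)
                 .batch(8)
                 .prefetch(tf.data.AUTOTUNE))
```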
It reaches around 98% accuracy, but after performing quantization-aware training its accuracy drops to around 50%:
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
q_aware_model.fit(train_dataset, epochs=10, validation_data=test_dataset)
[Image: training log showing accuracy before quantization-aware training (pre-QAT)]
[Image: training log showing accuracy after quantization-aware training (post-QAT)]
I tried training for more epochs, but the accuracy doesn't improve. I also tried quantizing only the Dense layers, as described here, but that didn't work either.