I'm trying to follow the fine-tuning steps described in https://www.tensorflow.org/tutorials/images/transfer_learning#create_the_base_model_from_the_pre-trained_convnets to get a trained model for binary segmentation.
I create an encoder-decoder network whose encoder uses the MobileNetV2 weights, frozen with encoder.trainable = False. Then I define my decoder as described in the tutorial and train the network for 300 epochs with a learning rate of 0.005. I get the following loss values and Jaccard index during the last epochs:
Epoch 297/300
55/55 [==============================] - 85s 2s/step - loss: 0.2443 - jaccard_sparse3D: 0.5556 - accuracy: 0.9923 - val_loss: 0.0440 - val_jaccard_sparse3D: 0.3172 - val_accuracy: 0.9768
Epoch 298/300
55/55 [==============================] - 75s 1s/step - loss: 0.2437 - jaccard_sparse3D: 0.5190 - accuracy: 0.9932 - val_loss: 0.0422 - val_jaccard_sparse3D: 0.3281 - val_accuracy: 0.9776
Epoch 299/300
55/55 [==============================] - 78s 1s/step - loss: 0.2465 - jaccard_sparse3D: 0.4557 - accuracy: 0.9936 - val_loss: 0.0431 - val_jaccard_sparse3D: 0.3327 - val_accuracy: 0.9769
Epoch 300/300
55/55 [==============================] - 85s 2s/step - loss: 0.2467 - jaccard_sparse3D: 0.5030 - accuracy: 0.9923 - val_loss: 0.0463 - val_jaccard_sparse3D: 0.3315 - val_accuracy: 0.9740
I save all the weights of this model and then run the fine-tuning with the following steps:
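from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

# Restore the weights from the decoder-only training stage
model.load_weights('my_pretrained_weights.h5')
# Unfreeze every layer, including the MobileNetV2 encoder
model.trainable = True
# Recompile with a much lower learning rate for fine-tuning
model.compile(optimizer=Adam(learning_rate=0.00001, name='adam'),
              loss=SparseCategoricalCrossentropy(from_logits=True),
              metrics=[jaccard, "accuracy"])
model.fit(training_generator, validation_data=(val_x, val_y), epochs=5,
          validation_batch_size=2, callbacks=callbacks)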
Suddenly, the performance of my model is much worse than during the training of the decoder:
Epoch 1/5
55/55 [==============================] - 89s 2s/step - loss: 0.2417 - jaccard_sparse3D: 0.0843 - accuracy: 0.9946 - val_loss: 0.0079 - val_jaccard_sparse3D: 0.0312 - val_accuracy: 0.9992
Epoch 2/5
55/55 [==============================] - 90s 2s/step - loss: 0.1920 - jaccard_sparse3D: 0.1179 - accuracy: 0.9927 - val_loss: 0.0138 - val_jaccard_sparse3D: 7.1138e-05 - val_accuracy: 0.9998
Epoch 3/5
55/55 [==============================] - 95s 2s/step - loss: 0.2173 - jaccard_sparse3D: 0.1227 - accuracy: 0.9932 - val_loss: 0.0171 - val_jaccard_sparse3D: 0.0000e+00 - val_accuracy: 0.9999
Epoch 4/5
55/55 [==============================] - 94s 2s/step - loss: 0.2428 - jaccard_sparse3D: 0.1319 - accuracy: 0.9927 - val_loss: 0.0190 - val_jaccard_sparse3D: 0.0000e+00 - val_accuracy: 1.0000
Epoch 5/5
55/55 [==============================] - 97s 2s/step - loss: 0.1920 - jaccard_sparse3D: 0.1107 - accuracy: 0.9926 - val_loss: 0.0215 - val_jaccard_sparse3D: 0.0000e+00 - val_accuracy: 1.0000
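In the logs above, accuracy stays close to 1.0 while the Jaccard index falls to 0. The two can coexist when the foreground class is very small. The sketch below illustrates this on a made-up mask. It assumes the metric is the standard intersection-over-union of the foreground class, which may differ from how jaccard_sparse3D is actually defined, and the pixel counts are hypothetical.

```python
def jaccard_index(y_true, y_pred, positive=1):
    """Intersection over union of the positive class for flat label lists."""
    pairs = list(zip(y_true, y_pred))
    intersection = sum(1 for t, p in pairs if t == positive and p == positive)
    union = sum(1 for t, p in pairs if t == positive or p == positive)
    # Convention chosen here: an empty union counts as a perfect match
    return intersection / union if union else 1.0


def pixel_accuracy(y_true, y_pred):
    """Fraction of pixels whose predicted label equals the true label."""
    return sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)


# Hypothetical mask: 10,000 pixels, only 5 of them foreground
y_true = [1] * 5 + [0] * 9995
# A prediction that labels every pixel as background
all_background = [0] * 10000

print(pixel_accuracy(y_true, all_background))  # 0.9995
print(jaccard_index(y_true, all_background))   # 0.0
```

If this assumption holds, a validation accuracy of 1.0000 with a Jaccard index of 0 would be consistent with the fine-tuned model predicting background almost everywhere.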
Is there any known reason why this is happening? Is it normal? Thank you in advance!
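For comparison, the fine-tuning step in the linked tutorial differs from the snippet above in two ways: the base model is called with training=False so that its BatchNormalization layers stay in inference mode, and only the top layers of the base model are unfrozen. Below is a minimal sketch of that pattern, not my actual network. The input shape, the tiny decoder head and weights=None are assumptions made to keep the example small (the tutorial loads ImageNet weights), and the Adam optimizer with the learning rate from the post replaces the optimizer used in the tutorial.

```python
import tensorflow as tf

IMG_SHAPE = (160, 160, 3)  # assumed input size, not taken from the post

# weights=None avoids a download here; the tutorial uses weights="imagenet"
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE, include_top=False, weights=None)

inputs = tf.keras.Input(shape=IMG_SHAPE)
# training=False keeps BatchNormalization in inference mode, even after unfreezing
x = base_model(inputs, training=False)
# Hypothetical minimal decoder: per-pixel logits for 2 classes, upsampled 32x
x = tf.keras.layers.Conv2D(2, 1)(x)
outputs = tf.keras.layers.UpSampling2D(size=32, interpolation="bilinear")(x)
model = tf.keras.Model(inputs, outputs)

# Fine-tuning: unfreeze the base model, then re-freeze its lower layers
base_model.trainable = True
fine_tune_at = 100  # value used in the tutorial
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile so the trainable changes take effect, with a low learning rate
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"])
```

In my code, model.trainable = True unfreezes everything at once, so I am not sure whether the BatchNormalization layers of the encoder end up being updated during fine-tuning.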