
I am in the process of implementing a ResNet50 model on medical images from four classes. I initially had a dataset of 250 images per class, which I split into two folders, train and val. I used the val data as the test set, and split the train data 80/20 into training and validation sets.
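A minimal NumPy sketch of the 80/20 split of the train folder, done per class so every class keeps the same proportion in both the training and validation sets. The helper name `stratified_split`, the seed, and the figure of 200 train-folder images per class are assumptions for illustration; the post does not say how the 250 images per class were divided between the two folders.

```python
import numpy as np

def stratified_split(labels, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx), splitting each class in the same proportion."""
    rng = np.random.default_rng(seed)
    train_idx, val_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        rng.shuffle(idx)                     # shuffle within the class
        n_val = int(round(len(idx) * val_fraction))
        val_idx.extend(idx[:n_val])
        train_idx.extend(idx[n_val:])
    return np.array(train_idx), np.array(val_idx)

# Hypothetical: 200 train-folder images for each of the 4 classes.
labels = np.repeat(np.arange(4), 200)
train_idx, val_idx = stratified_split(labels, val_fraction=0.2)
print(len(train_idx), len(val_idx))  # 640 160
```

Splitting per class avoids the case where a random split leaves one class under-represented in validation, which matters with only a few hundred images per class.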

I tried some code I found online and have experimented a lot, but nothing improves the validation accuracy, even though the training accuracy is decent (and can be improved later).

Please suggest ways to improve the validation accuracy for my problem. PS: Although the images are black and white, I used an input shape of (224, 224, 3) because I couldn't adapt the code I found to grayscale images. I hope that's not the main issue here.
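Using a (224, 224, 3) input with grayscale images is a common workaround: the ImageNet-pretrained ResNet50 weights expect three channels, and a grayscale image can be represented by repeating its single channel. A minimal NumPy sketch, assuming the images are already loaded as an array; `gray_to_rgb` is a hypothetical helper name. If the images are read with Keras's `load_img` in its default RGB color mode, this replication likely already happens on load.

```python
import numpy as np

def gray_to_rgb(gray_batch):
    """Turn an (N, H, W) or (N, H, W, 1) grayscale batch into (N, H, W, 3)."""
    if gray_batch.ndim == 3:
        gray_batch = gray_batch[..., np.newaxis]  # add a channel axis
    # Copy the single channel three times to match the 3-channel pretrained input.
    return np.repeat(gray_batch, 3, axis=-1)

gray = np.random.default_rng(0).random((2, 224, 224))  # hypothetical batch of two images
rgb = gray_to_rgb(gray)
print(rgb.shape)  # (2, 224, 224, 3)
```

Any preprocessing the pipeline already applies (such as `preprocess_input`) would still run on the 3-channel result.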

Reference code: https://github.com/anujshah1003/Transfer-Learning-in-keras---custom-data/blob/master/transfer_learning_resnet50_custom_data.py

The only changes I made were the dataset directories and removing the Flatten layer: the output of the avg_pool layer is already flat, so the Dense layer can be applied directly.
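Why the Flatten layer can be dropped, illustrated with NumPy on a hypothetical random feature map. This assumes a Keras version in which ResNet50's `avg_pool` layer is a `GlobalAveragePooling2D` (output shape `(None, 2048)`); some older Keras releases used a 7x7 `AveragePooling2D` there, whose `(None, 1, 1, 2048)` output does need flattening.

```python
import numpy as np

# Hypothetical stand-in for ResNet50's last feature map: batch of 2, 7x7 grid, 2048 channels.
features = np.random.default_rng(0).random((2, 7, 7, 2048))

# Global average pooling averages over both spatial axes: one 2048-d vector per image.
pooled = features.mean(axis=(1, 2))
print(pooled.shape)  # (2, 2048)

# A 7x7 (non-global) average pool keeps singleton spatial axes instead.
pooled_4d = features.mean(axis=(1, 2), keepdims=True)
print(pooled_4d.shape)  # (2, 1, 1, 2048)

# Flattening the 4-D version gives exactly the global-pooled vectors.
flattened = pooled_4d.reshape(len(pooled_4d), -1)
print(flattened.shape)  # (2, 2048)
```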

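import time
from keras.layers import Dense
from keras.models import Model

# `model` (ResNet50 with include_top=True), `image_input`, `num_classes` and the
# X_/y_ arrays are defined earlier, following the reference script.
last_layer = model.get_layer('avg_pool').output  # (None, 2048) when avg_pool is GlobalAveragePooling2D
out = Dense(num_classes, activation='softmax', name='output_layer')(last_layer)
custom_resnet_model = Model(inputs=image_input, outputs=out)
# Compile with metrics=['accuracy'] before fit so evaluate() returns (loss, accuracy).

t = time.time()
# The test split is passed as validation_data here, as in the reference script.
hist = custom_resnet_model.fit(X_train, y_train, batch_size=32, epochs=12, verbose=1,
                               validation_data=(X_test, y_test))
print('Training time: %s' % (time.time() - t))  # elapsed seconds
(loss, accuracy) = custom_resnet_model.evaluate(X_test, y_test, batch_size=10, verbose=1)
print("[INFO] loss={:.4f}, accuracy: {:.4f}%".format(loss, accuracy * 100))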

Output after 12 epochs

Insomniac

1 Answer


This is a very common case: your model is not able to generalize to unseen data. You can try the following steps to overcome this:

  1. Collecting more data: Try to collect more varied data. It will help your model generalize across classes.
  2. Augmentation: A very common and very useful technique. Try different rotation angles, different contrast, zooming in and out, etc.
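A minimal NumPy sketch of the augmentations mentioned in item 2 (flips, rotations, zoom, contrast), assuming square images stored as float arrays scaled to [0, 1]. The `augment` helper and its parameter ranges are illustrative assumptions, and rotation is limited to multiples of 90 degrees to avoid interpolation. Keras also provides preprocessing layers such as `RandomFlip`, `RandomRotation`, `RandomZoom` and `RandomContrast` that can sit in front of the model instead.

```python
import numpy as np

def augment(image, rng):
    """Return a randomly augmented copy of a square (H, W, C) image in [0, 1]."""
    out = image
    # Random horizontal flip.
    if rng.random() < 0.5:
        out = out[:, ::-1, :]
    # Random rotation by a multiple of 90 degrees (keeps the shape for square images).
    out = np.rot90(out, k=int(rng.integers(0, 4)), axes=(0, 1))
    # Random zoom-in: crop a central region, then resize back with nearest-neighbour sampling.
    h, w = out.shape[:2]
    scale = rng.uniform(0.8, 1.0)
    ch, cw = int(h * scale), int(w * scale)
    top, left = (h - ch) // 2, (w - cw) // 2
    crop = out[top:top + ch, left:left + cw, :]
    rows = (np.arange(h) * ch / h).astype(int)
    cols = (np.arange(w) * cw / w).astype(int)
    out = crop[rows][:, cols]
    # Random contrast: stretch or compress intensities around the mean, then clip to [0, 1].
    factor = rng.uniform(0.8, 1.2)
    mean = out.mean()
    return np.clip((out - mean) * factor + mean, 0.0, 1.0)

rng = np.random.default_rng(0)
img = rng.random((224, 224, 3))  # hypothetical image
aug = augment(img, rng)
print(aug.shape)  # (224, 224, 3)
```

Augmentation is applied to the training images only, not to the validation or test images. It is also worth checking that each transform is plausible for the imaging modality (for example, whether left-right flips produce realistic scans).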

If you still face the same issue, plot the confusion matrix to see where your model struggles most. Then you can analyze the data for those specific classes.
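A minimal NumPy sketch of building the confusion matrix and per-class recall. It assumes integer class indices; with one-hot labels and softmax outputs these would come from `argmax` (for example `y_test.argmax(axis=1)` and `custom_resnet_model.predict(X_test).argmax(axis=1)`). The small label arrays below are made up for illustration.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair
    return cm

# Made-up integer labels for a 4-class problem.
y_true = np.array([0, 0, 1, 1, 2, 2, 3, 3])
y_pred = np.array([0, 1, 1, 1, 2, 0, 3, 3])

cm = confusion_matrix(y_true, y_pred, num_classes=4)
recall = cm.diagonal() / cm.sum(axis=1)  # fraction of each true class predicted correctly
print(cm)
print(recall)
```

In this toy example classes 0 and 2 each have a recall of 0.5, so those would be the classes whose images to inspect first; the off-diagonal cells show which classes they are confused with.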

By the way, for the grayscale image issue, take a look at: How can I use a pre-trained neural network with grayscale images?

Nazmul Hasan