We are working on semantic segmentation for our thesis, with the goal of stratifying lesions in 2D T2W images of the prostate. It is a 3-class problem with the following labels: 0 = background, 1 = insignificant tumor, 2 = significant tumor. We build our segmentation model with the U-Net from Segmentation Models: https://github.com/qubvel/segmentation_models
DATA: Our data consists of images from 200 patients, where each image associated with a patient represents a suspicious lesion. Some patients have only one image; the maximum is 3. Each image is sliced into 24 slices, which together should cover the whole lesion in 3D. We of course also have the corresponding mask for each lesion image. Note that some mask slices contain no lesion at all, for example because the lesion is very small in volume, and there are many such slices where the mask is background only. In short, the mask pixels are labelled 0 to 2 according to their content, as explained above.
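To make the structure above concrete, here is a minimal NumPy sketch of how the per-lesion volumes could be flattened into one slice array while remembering which patient each slice came from. It assumes each lesion image and its mask are arrays of shape (24, H, W) with integer labels 0 to 2; the helper name `stack_lesions` and the patient IDs are hypothetical.

```python
import numpy as np

def stack_lesions(volumes, masks, patient_ids):
    """Flatten per-lesion volumes into one slice array (hypothetical helper).

    volumes, masks: lists of arrays shaped (24, H, W), one entry per lesion image.
    patient_ids:    patient ID of each lesion image (same length as volumes).
    Returns images (N, H, W, 1), masks (N, H, W) and a per-slice patient ID array.
    """
    images = np.concatenate(volumes, axis=0)[..., np.newaxis]  # add channel axis
    labels = np.concatenate(masks, axis=0)
    # Repeat each lesion's patient ID once per slice so slices stay traceable.
    slice_patients = np.concatenate(
        [np.full(v.shape[0], pid) for v, pid in zip(volumes, patient_ids)]
    )
    return images, labels, slice_patients

# Toy example: patient "p1" has two lesion images, patient "p2" has one.
rng = np.random.default_rng(0)
vols = [rng.random((24, 8, 8)) for _ in range(3)]
msks = [rng.integers(0, 3, size=(24, 8, 8)) for _ in range(3)]
imgs, labs, pids = stack_lesions(vols, msks, ["p1", "p1", "p2"])
print(imgs.shape, labs.shape, pids.shape)  # (72, 8, 8, 1) (72, 8, 8) (72,)
```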
PREPROCESSING: The T2W slices are normalized and N4 bias field correction is applied (using SimpleITK). The lesion masks are one-hot encoded using the Keras backend and are therefore not normalized.
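A minimal sketch of the NumPy side of these two steps, assuming per-slice z-score normalization (the normalization scheme is not stated above, so this is an assumption) and integer masks. `np.eye(3)[mask]` gives the same one-hot layout as `tf.keras.utils.to_categorical(mask, num_classes=3)`; the SimpleITK N4 step is left out here.

```python
import numpy as np

NUM_CLASSES = 3  # 0 = background, 1 = insignificant tumor, 2 = significant tumor

def zscore_slice(img, eps=1e-8):
    """Normalize one T2W slice to zero mean and unit variance (assumed scheme)."""
    img = img.astype(np.float32)
    return (img - img.mean()) / (img.std() + eps)

def one_hot(mask, num_classes=NUM_CLASSES):
    """Integer mask (H, W) -> one-hot mask (H, W, num_classes); values stay 0/1."""
    return np.eye(num_classes, dtype=np.float32)[mask]

mask = np.array([[0, 1], [2, 0]])
print(one_hot(mask)[1, 0])  # [0. 0. 1.]
```

The one-hot masks contain only 0 and 1, which is why they need no normalization, unlike the image intensities.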
The lesion masks are ndarrays, so we checked the pixel distribution per class using np.unique. There was a huge imbalance, with the large majority being background pixels: array([491101591, 262695, 155714]). We therefore decided to remove the background-only mask slices, bearing in mind that they could cause problems during training. After this, background pixels still outnumbered class 1 and class 2 pixels, but that was expected given the nature of the data.
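For scale, the reported counts mean background is about 99.915% of all pixels, class 1 about 0.053% and class 2 about 0.032%. The sketch below reproduces that calculation and shows one way to count classes and drop background-only slices, assuming images of shape (N, H, W, 1) and integer masks of shape (N, H, W); the function names are hypothetical.

```python
import numpy as np

def class_pixel_counts(masks, num_classes=3):
    """Pixel count per class for integer masks of any shape."""
    labels, counts = np.unique(masks, return_counts=True)
    full = np.zeros(num_classes, dtype=np.int64)
    full[labels] = counts  # classes absent from `masks` keep a count of 0
    return full

def drop_background_only(images, masks):
    """Keep only slices whose mask contains at least one non-background pixel."""
    keep = (masks.reshape(len(masks), -1) > 0).any(axis=1)
    return images[keep], masks[keep]

# The counts reported above, expressed as a percentage of all pixels.
counts = np.array([491101591, 262695, 155714])
print(np.round(100 * counts / counts.sum(), 3))
```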
HYPERPARAMETERS: For now the metric we use is the Dice score, which we implemented ourselves for the multiclass case. Accordingly, our loss is the Dice loss. Our Dice score function returns a score for each class, and a second function computes the overall (global) Dice score.
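Since the custom Dice implementation is not shown, here is a NumPy reference sketch that could be used to cross-check it on small arrays. It assumes one-hot `y_true` and softmax `y_pred` of shape (N, H, W, 3), sums over the whole batch, and defines the global score as the unweighted mean over classes; both choices are assumptions, and a TensorFlow version would swap `np.sum` for `tf.reduce_sum`.

```python
import numpy as np

def dice_per_class(y_true, y_pred, smooth=1e-6):
    """Soft Dice per class for one-hot y_true and softmax y_pred, both (N, H, W, C).

    Sums run over the whole batch (N, H, W), giving one score per class.
    """
    axes = (0, 1, 2)
    intersection = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    return (2.0 * intersection + smooth) / (denom + smooth)

def dice_global(y_true, y_pred, smooth=1e-6):
    """Unweighted mean of the per-class Dice scores (assumed definition)."""
    return float(np.mean(dice_per_class(y_true, y_pred, smooth)))

def dice_loss(y_true, y_pred, smooth=1e-6):
    return 1.0 - dice_global(y_true, y_pred, smooth)

mask = np.array([[[0, 1], [2, 0]]])      # one 2x2 slice
y_true = np.eye(3)[mask]                 # (1, 2, 2, 3)
y_bg = np.eye(3)[np.zeros_like(mask)]    # predicts background everywhere
print(dice_per_class(y_true, y_true))    # all classes ~1.0
print(dice_per_class(y_true, y_bg))
```

The `smooth` term and the summation axes decide what happens when a lesion class is missing. With hard (argmax) predictions, a class that is absent and not predicted scores 1, while a class that is absent but predicted on even a few pixels scores about 0; with soft predictions the result also depends on `smooth` relative to the summed probabilities. So the summation axes and the number of validation slices without class 1 or 2 pixels can make those per-class scores jump.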
epochs = 130, lr = 1e-5, batch size = 32. The learning rate is also reduced during training with tf.keras.callbacks.ReduceLROnPlateau().
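To illustrate how the learning rate can evolve under this callback, here is a simplified pure-Python re-implementation of the plateau logic of `tf.keras.callbacks.ReduceLROnPlateau` in "min" mode, using the Keras defaults `factor=0.1`, `patience=10`, `min_delta=1e-4`, `cooldown=0`, `min_lr=0`. The arguments actually passed in the setup above are not stated, so these values are assumptions, and the function name is hypothetical.

```python
def simulate_reduce_lr_on_plateau(val_losses, lr=1e-5, factor=0.1, patience=10,
                                  min_delta=1e-4, cooldown=0, min_lr=0.0):
    """Simplified model of ReduceLROnPlateau (mode="min").

    Returns the learning rate in effect after each epoch's end-of-epoch check.
    """
    best = float("inf")
    wait = 0
    cooldown_counter = 0
    history = []
    for loss in val_losses:
        if cooldown_counter > 0:          # still cooling down after a reduction
            cooldown_counter -= 1
            wait = 0
        if loss < best - min_delta:       # counts as an improvement
            best = loss
            wait = 0
        elif cooldown_counter == 0:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                cooldown_counter = cooldown
                wait = 0
        history.append(lr)
    return history

lrs = simulate_reduce_lr_on_plateau([0.6] * 130)
print(lrs[9], lrs[10], lrs[-1])
```

For example, if the validation loss stayed flat at 0.6 for all 130 epochs, the rate starting at 1e-5 would be cut at epochs 11, 21, ..., 121 and end near 1e-17, so later epochs would make practically no updates; setting `min_lr` puts a floor under this.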
We use resnet34 as the backbone of our U-Net model, with ImageNet weights, softmax as the activation, and tf.keras.optimizers.Adam() as the optimizer.
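The ImageNet weights of the resnet34 encoder were learned on 3-channel RGB input, while T2W slices have a single channel. A minimal sketch of one common workaround, assuming inputs of shape (slices, h, w, 1), is to repeat the channel three times; the Segmentation Models README also describes putting a 1x1 `Conv2D` in front of the pretrained model to map N channels to 3. The ResNet encoder downsamples by a factor of 32, so h and w generally need to be multiples of 32. The helper names below are hypothetical.

```python
import numpy as np

def to_three_channels(images):
    """Repeat the single grayscale channel: (N, H, W, 1) -> (N, H, W, 3)."""
    if images.shape[-1] != 1:
        raise ValueError(f"expected 1 channel, got {images.shape[-1]}")
    return np.repeat(images, 3, axis=-1)

def fits_resnet_unet(height, width, factor=32):
    """True if H and W divide evenly through the encoder's 32x downsampling."""
    return height % factor == 0 and width % factor == 0

slices = np.zeros((4, 256, 256, 1), dtype=np.float32)
rgb_like = to_three_channels(slices)
print(rgb_like.shape, fits_resnet_unet(256, 256))  # (4, 256, 256, 3) True
```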
TRAINING PIPELINE: We decided to use k-fold cross-validation for training and testing. Before training (fitting) the model, we end up with 3 datasets: a training set, an "internal" test set, and a prediction set (the latter is not so relevant yet, as we have not got that far). We make sure they have the correct shape (slices, h, w, channel). Then the background-only slices are removed, and finally we apply some augmentation using the Albumentations library: flipping, rotation and translation. The augmented data is added to the respective T2W image and mask sets, for both the training and validation sets. We did this to obtain more data, as we struggled a lot with overfitting.
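A sketch of one way to organise the folds and the augmentation, assuming a 1-D array that holds the patient ID of every slice. Splitting by patient keeps all slices of a lesion in the same fold, so near-identical neighbouring slices cannot end up in both training and validation; scikit-learn's `GroupKFold` implements the same grouping. The augmentation uses only flips and 90-degree rotations in NumPy as a stand-in for the Albumentations pipeline; Albumentations applies the same sampled transform to an image and its mask when both are passed in one call, e.g. `transform(image=img, mask=msk)`. All names here are hypothetical.

```python
import numpy as np

def patient_kfold(slice_patients, k=5, seed=0):
    """Yield (train_idx, val_idx) index arrays where each fold holds whole patients.

    slice_patients: 1-D array giving the patient ID of every slice.
    """
    patients = np.unique(slice_patients)
    rng = np.random.default_rng(seed)
    rng.shuffle(patients)  # random patient order before splitting into k groups
    for fold_patients in np.array_split(patients, k):
        val_mask = np.isin(slice_patients, fold_patients)
        yield np.flatnonzero(~val_mask), np.flatnonzero(val_mask)

def augment_pair(image, mask, rng):
    """Apply the same random horizontal flip and 90-degree rotation to image and mask."""
    if rng.random() < 0.5:
        image, mask = np.flip(image, axis=1), np.flip(mask, axis=1)
    k = int(rng.integers(4))
    return np.rot90(image, k, axes=(0, 1)), np.rot90(mask, k, axes=(0, 1))

pids = np.repeat(np.array(["p1", "p2", "p3", "p4", "p5"]), 24)
for train_idx, val_idx in patient_kfold(pids, k=5):
    print(len(train_idx), len(val_idx))  # 96 24 on every fold
```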
Training is slow, but after some epochs we get good Dice scores for each class and overall, and the Dice loss decreases fairly quickly. Validation is another story. The validation metrics for the non-background classes (insignificant/significant lesion, i.e. 1/2) fluctuate, ranging between roughly 1e-1 and 1e-4, while the validation Dice score for class 0 is good and converges to almost 1. In general, our validation loss won't go below 0.6. So we suspect overfitting, but combined with oscillating validation metrics. By comparison, the training Dice scores for classes 1 and 2 usually converge to 0.9 or higher, for class 1 almost 0.99. The fluctuation does not happen every other epoch either; there is no pattern, and it seems to happen randomly from epoch to epoch.
We feel like we have tried almost everything, and we have done a lot of research on possible explanations and solutions for this kind of behaviour. Some of what we tried comes from similar questions here on Stack Overflow about preventing overfitting, as described above. We have of course also tried changing the learning rate and batch size. Thanks in advance! This is my first time posting, so I apologize for any confusion my description might cause.