I am working on a multilabel classification problem with an imbalanced dataset. There are 1130 samples in total: the first class occurs in 913 of them, the second class in 215, and the third in 423.
In the model architecture, I have 3 output nodes with sigmoid activation.
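import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

input_tensor = Input(shape=(256, 256, 3))
base_model = VGG16(input_tensor=input_tensor, weights='imagenet', pooling=None, include_top=False)
#base_model.summary()
x = base_model.output
x = GlobalAveragePooling2D()(x)
# Element-wise max over the batch axis: (N, 512) -> (1, 512)
x = tf.math.reduce_max(x, axis=0, keepdims=True)
x = Dense(512, activation='relu')(x)
output_1 = Dense(3, activation='sigmoid')(x)
sagittal_model_abn = Model(inputs=base_model.input, outputs=output_1)
# Fine-tune every VGG16 layer
for layer in base_model.layers:
    layer.trainable = True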
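To make the effect of the reduce_max line concrete: with axis=0 and keepdims=True it takes the element-wise maximum across the batch dimension, so everything fed in one batch is merged into a single row, which is why the prediction has shape (1, 3). Here is a plain-Python sketch of the same reduction. The helper name and the feature values are made up for illustration and are not part of the model.

```python
def reduce_max_axis0_keepdims(rows):
    """Element-wise maximum across rows; returns a single row wrapped in a list."""
    return [[max(column) for column in zip(*rows)]]

# Three hypothetical feature vectors, standing in for three images in one batch.
features = [
    [0.1, 0.9, 0.3],
    [0.7, 0.2, 0.4],
    [0.5, 0.6, 0.8],
]

pooled = reduce_max_axis0_keepdims(features)
print(pooled)  # [[0.7, 0.9, 0.8]]
```

Note that this also means a batch containing images from different samples would be collapsed into one prediction.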
I am using binary cross-entropy loss, which I calculate using this function. I weight the loss to deal with the imbalance.
if y_true[0]==1:
    loss_abn = -1*K.log(y_pred[0][0])*cwb[0][1]
elif y_true[0]==0:
    loss_abn = -1*K.log(1-y_pred[0][0])*cwb[0][0]
if y_true[1]==1:
    loss_acl = -1*K.log(y_pred[0][1])*cwb[1][1]
elif y_true[1]==0:
    loss_acl = -1*K.log(1-y_pred[0][1])*cwb[1][0]
if y_true[2]==1:
    loss_men = -1*K.log(y_pred[0][2])*cwb[2][1]
elif y_true[2]==0:
    loss_men = -1*K.log(1-y_pred[0][2])*cwb[2][0]
loss_value_ds = loss_abn + loss_acl + loss_men
cwb contains the class weights.
y_true is the ground truth labels, with length 3.
y_pred is a numpy array with shape (1,3).
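For checking values by hand, here is a plain-Python reference version of the per-sample loss above. It is only a sketch: the function name weighted_bce is made up, it takes y_pred as a flat list of 3 probabilities rather than a (1,3) array, and it uses math.log, so it is not differentiable and is not meant to replace the training loss.

```python
import math

def weighted_bce(y_true, y_pred, cwb):
    """Weighted binary cross-entropy for one sample, summed over classes.

    y_true: sequence of 0/1 labels, one per class.
    y_pred: sequence of predicted probabilities, one per class.
    cwb:    {class_index: {0: weight_for_label_0, 1: weight_for_label_1}}
    """
    total = 0.0
    for i, (t, p) in enumerate(zip(y_true, y_pred)):
        if t == 1:
            total += -math.log(p) * cwb[i][1]
        else:
            total += -math.log(1.0 - p) * cwb[i][0]
    return total

# Class weights as listed in the NOTE of this question.
cwb = {0: {0: 5.207373271889401, 1: 1.2376779846659365},
       1: {0: 1.2255965292841648, 1: 5.4326923076923075},
       2: {0: 1.5416098226466575, 1: 2.8463476070528966}}

# A hypothetical sample labelled [1, 0, 1] with outputs close to 0.5.
loss = weighted_bce([1, 0, 1], [0.51018655, 0.5010625, 0.50482965], cwb)
print(loss)
```

Comparing this number against what the TensorFlow version returns for the same inputs is one way to confirm that the loss itself is computed as intended.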
I weight each class by the occurrence and non-occurrence of its label: a label of 1 counts as an occurrence and a label of 0 as a non-occurrence.
For example, label 1 of the first class occurs 913 times out of 1130, so its weight is 1130/913, which is about 1.24, and the weight of label 0 for the first class is 1130/(1130-913), which is about 5.21.
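The weighting rule just described can be sketched as a small helper. The name label_weights is hypothetical, and the counts are the ones quoted above for the first class.

```python
def label_weights(n_total, n_positive):
    """Return {0: weight for label 0, 1: weight for label 1} for one class."""
    n_negative = n_total - n_positive
    return {0: n_total / n_negative, 1: n_total / n_positive}

# First class: label 1 occurs 913 times out of 1130 samples.
w = label_weights(1130, 913)
print(w)
```

The rarer label always gets the larger weight, so for the first class a wrong prediction on label 0 costs roughly four times as much as one on label 1.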
When I train the model, the accuracy oscillates (or stays almost the same) while the loss decreases.
I am getting predictions like this for every sample:
[[0.51018655 0.5010625 0.50482965]]
The prediction values are in the range 0.49 - 0.51 in every iteration for all the classes
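One sanity check on these numbers, sketched under the assumption that each label weight is total/count as described above and that the counts are the ones quoted at the top: if every output were exactly 0.5, the mean weighted loss over the dataset would be 6 * ln 2, about 4.159. That is very close to the training loss of about 4.159 reported in the model.fit() log in the EDIT, which would be consistent with the network outputting roughly 0.5 whatever the input.

```python
import math

n_total = 1130
positives = [913, 215, 423]  # per-class occurrence counts quoted in the question

expected = 0.0
for n_pos in positives:
    n_neg = n_total - n_pos
    w_pos, w_neg = n_total / n_pos, n_total / n_neg
    # Mean weight over the dataset: each label's weight times how often it occurs.
    # With weights of the form total/count this is 2 for every class.
    mean_weight = (n_pos * w_pos + n_neg * w_neg) / n_total
    # -log(0.5) is the unweighted loss of a 0.5 prediction, whatever the label.
    expected += mean_weight * -math.log(0.5)

print(round(expected, 4))
```

Because the mean weight is 2 for any counts under this weighting scheme, the result does not depend on the exact class frequencies.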
I tried changing the number of nodes in the FC layer, but it still behaves the same way.
Can anyone help?
Does using tf.math.reduce_max cause the problem? Would it be beneficial to use @tf.function for the operation that I am doing with tf.math.reduce_max?
NOTE:
I am weighting the labels 1 and 0 for each class separately.
cwb = {0: {0: 5.207373271889401, 1: 1.2376779846659365},
       1: {0: 1.2255965292841648, 1: 5.4326923076923075},
       2: {0: 1.5416098226466575, 1: 2.8463476070528966}}
EDIT:
The results when I train using model.fit():
Epoch 1/20
1130/1130 [==============================] - 1383s 1s/step - loss: 4.1638 - binary_accuracy: 0.4558 - val_loss: 5.0439 - val_binary_accuracy: 0.3944
Epoch 2/20
1130/1130 [==============================] - 1397s 1s/step - loss: 4.1608 - binary_accuracy: 0.4165 - val_loss: 5.0526 - val_binary_accuracy: 0.5194
Epoch 3/20
1130/1130 [==============================] - 1402s 1s/step - loss: 4.1608 - binary_accuracy: 0.4814 - val_loss: 5.1469 - val_binary_accuracy: 0.6361
Epoch 4/20
1130/1130 [==============================] - 1407s 1s/step - loss: 4.1722 - binary_accuracy: 0.4472 - val_loss: 5.0501 - val_binary_accuracy: 0.5583
Epoch 5/20
1130/1130 [==============================] - 1397s 1s/step - loss: 4.1591 - binary_accuracy: 0.4991 - val_loss: 5.0521 - val_binary_accuracy: 0.6028
Epoch 6/20
1130/1130 [==============================] - 1375s 1s/step - loss: 4.1596 - binary_accuracy: 0.5431 - val_loss: 5.0515 - val_binary_accuracy: 0.5917
Epoch 7/20
1130/1130 [==============================] - 1370s 1s/step - loss: 4.1595 - binary_accuracy: 0.4962 - val_loss: 5.0526 - val_binary_accuracy: 0.6000
Epoch 8/20
1130/1130 [==============================] - 1387s 1s/step - loss: 4.1591 - binary_accuracy: 0.5316 - val_loss: 5.0523 - val_binary_accuracy: 0.6028
Epoch 9/20
1130/1130 [==============================] - 1391s 1s/step - loss: 4.1590 - binary_accuracy: 0.4909 - val_loss: 5.0521 - val_binary_accuracy: 0.6028
Epoch 10/20
1130/1130 [==============================] - 1400s 1s/step - loss: 4.1590 - binary_accuracy: 0.5369 - val_loss: 5.0519 - val_binary_accuracy: 0.6028
Epoch 11/20
1130/1130 [==============================] - 1397s 1s/step - loss: 4.1590 - binary_accuracy: 0.4808 - val_loss: 5.0519 - val_binary_accuracy: 0.6028
Epoch 12/20
1130/1130 [==============================] - 1394s 1s/step - loss: 4.1590 - binary_accuracy: 0.5469 - val_loss: 5.0522 - val_binary_accuracy: 0.6028