I am using a weighted binary cross-entropy + Dice loss for a segmentation problem with a strong class imbalance (80 times more black pixels than white pixels):
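```python
from keras import backend as K

# weighted_bce_loss(y_true, y_pred, weight) and dice_loss(y_true, y_pred)
# are helper functions that are not shown in this post.

def weighted_bce_dice_loss(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    # Local fraction of white pixels in a 50x50 window around each pixel
    averaged_mask = K.pool2d(
        y_true, pool_size=(50, 50), strides=(1, 1), padding='same', pool_mode='avg')
    weight = K.ones_like(averaged_mask)
    w0 = K.sum(weight)  # total number of pixels
    # Largest weight where the local white fraction is close to 0.5 (object borders)
    weight = 5. * K.exp(-5. * K.abs(averaged_mask - 0.5))
    w1 = K.sum(weight)
    weight *= (w0 / w1)  # rescale so the weights sum to the pixel count
    loss = weighted_bce_loss(y_true, y_pred, weight) + dice_loss(y_true, y_pred)
    return loss
```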
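To see what the weight term emphasizes, the sketch below re-implements only the weight map in NumPy. It is an approximation, not the Keras code: the `box_average` and `boundary_weights` helpers are hypothetical, `box_average` assumes TensorFlow-style `'same'` average pooling (mean over in-bounds pixels only, with the extra row/column below/right for an even window), and the 120×120 mask with a 13×13 white square is made-up data at roughly the post's 80:1 ratio.

```python
import numpy as np

def box_average(mask, k):
    """Mean of a k x k window around each pixel, counting only in-bounds pixels.

    Rough stand-in (an assumption) for K.pool2d(..., strides=(1, 1),
    padding='same', pool_mode='avg'); for even k the extra row/column
    of the window sits below/right of the centre pixel.
    """
    h, w = mask.shape
    before = (k - 1) // 2          # window cells above/left of the centre
    after = k - 1 - before         # window cells below/right of the centre
    out = np.empty((h, w), dtype=float)
    for i in range(h):
        for j in range(w):
            window = mask[max(0, i - before):i + after + 1,
                          max(0, j - before):j + after + 1]
            out[i, j] = window.mean()
    return out

def boundary_weights(mask, k=50):
    """Same weight formula as weighted_bce_dice_loss, written in NumPy."""
    averaged_mask = box_average(mask.astype(float), k)
    w0 = averaged_mask.size                        # number of pixels
    weight = 5.0 * np.exp(-5.0 * np.abs(averaged_mask - 0.5))
    w1 = weight.sum()
    return weight * (w0 / w1)                      # mean weight becomes 1

# Hypothetical 120x120 mask: one 13x13 white square (about 84:1 black:white).
mask = np.zeros((120, 120))
mask[53:66, 53:66] = 1.0

weights = boundary_weights(mask, k=50)
white_to_far_black = weights[mask == 1].mean() / weights[0, 0]
white_share = weights[mask == 1].sum() / weights.sum()
print(f"white / distant-black weight ratio: {white_to_far_black:.2f}")
print(f"share of total weight on white pixels: {white_share:.3%}")
```

Because the averaged mask lies in [0, 1], the unnormalised weight 5·exp(−5·|m − 0.5|) lies between 5e^(−2.5) ≈ 0.41 and 5, so no pixel can get more than about e^2.5 ≈ 12 times the weight of another. With a 50×50 window around a small object the local white fraction stays far from 0.5, and in this toy example white pixels get only about 1.4 times the weight of distant black pixels. This term mainly up-weights the neighbourhood of object borders rather than rebalancing an 80:1 class ratio, which may be one factor worth checking.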
The Dice coefficient increases and the loss decreases, but after every epoch the output is a completely black image (all pixels are labelled black).
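A possible reading of the symptom, shown with made-up numbers: a Dice coefficient computed on raw sigmoid probabilities (a "soft" Dice; the `soft_dice` helper here is hypothetical, not the post's `dice_loss`) can keep rising while every white-pixel probability stays below the 0.5 threshold used to produce the output image. The last lines illustrate a separate, common cause of an all-black image: casting predictions in [0, 1] straight to `uint8` without scaling by 255.

```python
import numpy as np

def soft_dice(y_true, y_pred, smooth=1.0):
    """Dice computed on probabilities rather than on a thresholded mask."""
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

# Hypothetical 120x120 mask with a 13x13 white square (about 84:1 black:white).
mask = np.zeros((120, 120))
mask[53:66, 53:66] = 1.0

# Two hypothetical "epochs": the network raises its probability on the
# white pixels but never pushes it past 0.5.
for p_white in (0.1, 0.4):
    y_pred = mask * p_white
    binary = (y_pred > 0.5).astype(float)   # the mask that would be saved as an image
    print(f"p={p_white}: soft Dice={soft_dice(mask, y_pred):.3f}, "
          f"white pixels after thresholding={int(binary.sum())}")

# Separate pitfall: even a perfect prediction looks black if saved unscaled.
naive = mask.astype(np.uint8)              # values 0 or 1 out of 255
scaled = (mask * 255).astype(np.uint8)     # values 0 or 255
print("max pixel value, unscaled vs scaled:", naive.max(), scaled.max())
```

In this toy run soft Dice rises from about 0.19 to about 0.57 while the thresholded image has zero white pixels both times, so a rising Dice metric on its own does not guarantee a non-empty binary mask.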
Any ideas on why this is happening?