I'm performing pixel-wise multi-class classification using a U-Net architecture in Keras (TensorFlow backend) on many 256x256 images. I have 32 classes, stored as integer pixel values from 0 to 31 in 256x256 "mask" images, and my data generator one-hot encodes these masks, so each target is a 256x256x32 array.
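For reference, the encoding step can be sketched in NumPy as below. This is a minimal sketch assuming each mask is loaded as a (256, 256) integer array; `one_hot_mask` is a hypothetical helper, not the actual generator code. It also flattens the result to (65536, 32) so it lines up with the model's `Reshape((65536,32))` output. `keras.utils.to_categorical(mask, 32)` produces the same (H, W, 32) one-hot array before the reshape.

```python
import numpy as np

NUM_CLASSES = 32  # classes are stored as integer pixel values 0..31


def one_hot_mask(mask, num_classes=NUM_CLASSES):
    """Convert an (H, W) integer mask into an (H*W, num_classes) one-hot array.

    Hypothetical helper: the flattened shape matches a Reshape((H*W, C)) output.
    """
    mask = np.asarray(mask, dtype=np.int64)
    if mask.min() < 0 or mask.max() >= num_classes:
        raise ValueError("mask contains labels outside [0, num_classes)")
    # Row c of the identity matrix is the one-hot vector for class c.
    one_hot = np.eye(num_classes, dtype=np.float32)[mask]  # (H, W, C)
    return one_hot.reshape(-1, num_classes)                # (H*W, C)


# Toy mask: mostly class 0, with a 10x10 patch of class 7.
mask = np.zeros((256, 256), dtype=np.uint8)
mask[100:110, 50:60] = 7
y = one_hot_mask(mask)
print(y.shape)  # (65536, 32)
```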
However, most of the ground-truth arrays are empty; in other words, the most common class by far is 0. When I train my U-Net, it seems to overfit to class 0: the loss is low and the accuracy very high, but only because ~99% of the ground-truth pixels are 0, so the U-Net just outputs 0 almost everywhere. What I actually care about is the other 31 classes, i.e. how well it classifies the non-zero pixels in the ground truth.
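One way to quantify the imbalance, and to derive per-class weights from it, is to count pixels per class over the training masks. The sketch below uses a common inverse-frequency heuristic, w_c = N / (C * n_c), which is the formula behind scikit-learn's `class_weight="balanced"`; the function name `balanced_class_weights` and the choice to give absent classes weight 0 are assumptions for illustration.

```python
import numpy as np


def balanced_class_weights(masks, num_classes=32):
    """Return (pixel frequency per class, inverse-frequency weight per class).

    Hypothetical helper. Weight for class c: total_pixels / (num_classes * n_c).
    Classes that never occur get weight 0 instead of infinity.
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    for m in masks:
        labels = np.asarray(m, dtype=np.int64).ravel()
        if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
            raise ValueError("label outside [0, num_classes)")
        counts += np.bincount(labels, minlength=num_classes)

    total = counts.sum()
    weights = np.zeros(num_classes, dtype=np.float64)
    present = counts > 0
    weights[present] = total / (num_classes * counts[present])
    return counts / total, weights


# Toy example: a 4x4 mask with 15 pixels of class 0 and 1 pixel of class 1.
m = np.zeros((4, 4), dtype=np.uint8)
m[0, 0] = 1
freqs, w = balanced_class_weights([m])
```

With ~99% of pixels in class 0, class 0's weight ends up well below 1 while very rare classes can get weights in the hundreds, which may destabilize training; capping the weights or using their square root are common compromises.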
Is there a way to "weight" certain classes more heavily than others when computing the loss (and if so, would that approach be appropriate here)? I'm not sure whether this is an intrinsic problem with my data or a problem with my approach. Here's my U-Net:
```python
from keras.layers import (Input, Conv2D, MaxPooling2D, SpatialDropout2D,
                          UpSampling2D, Reshape, concatenate)
from keras.models import Model
from keras.optimizers import Adam

def unet(pretrained_weights = None, input_size = (256,256,1)):
    inputs = Input(input_size)
    # Encoder: two 3x3 convs per level, then 2x2 max pooling
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
    #drop4 = Dropout(0.5)(conv4)
    drop4 = SpatialDropout2D(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)
    # Bottleneck
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
    #drop5 = Dropout(0.5)(conv5)
    drop5 = SpatialDropout2D(0.5)(conv5)
    # Decoder: upsample, 2x2 conv, concatenate with the matching encoder features
    up6 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
    merge6 = concatenate([drop4,up6], axis = 3)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)
    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)
    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)
    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv1,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv9 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    # 1x1 conv with softmax over the last axis: 32 class probabilities per pixel
    conv10 = Conv2D(32, 1, activation = 'softmax')(conv9)
    #conv10 = Flatten()(conv10)
    #conv10 = Dense(65536, activation = 'softmax')(conv10)
    # Flatten the spatial dims: (256, 256, 32) -> (65536, 32), one row per pixel
    flat10 = Reshape((65536,32))(conv10)
    #conv10 = Conv1D(1, 1, activation='linear')(conv10)
    model = Model(inputs = inputs, outputs = flat10)
    opt = Adam(learning_rate = 1e-6, clipvalue = 0.01)
    model.compile(optimizer = opt, loss = 'categorical_crossentropy', metrics = ['categorical_accuracy'])
    #model.compile(optimizer = Adam(learning_rate = 1e-6), loss = 'binary_crossentropy', metrics = ['accuracy'])
    #model.summary()
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    return model
```
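The `model.compile` call above uses the built-in, unweighted `'categorical_crossentropy'`, in which every pixel counts equally, so the ~99% class-0 pixels dominate the loss. Keras also accepts a callable `loss(y_true, y_pred)`, so one option is a class-weighted cross-entropy. Below is a minimal NumPy reference of that formula (NumPy so it can be checked without building the model); the function name and the example weights are illustrative assumptions, not a Keras API.

```python
import numpy as np


def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean over pixels of -sum_c w_c * y_c * log(p_c).

    y_true: one-hot targets, shape (..., num_classes), e.g. (65536, 32).
    y_pred: softmax outputs with the same shape.
    class_weights: shape (num_classes,), broadcast over the class axis.
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    w = np.asarray(class_weights, dtype=np.float64)
    per_pixel = -np.sum(y_true * np.log(y_pred) * w, axis=-1)
    return per_pixel.mean()


# Two pixels, three classes: true classes are 0 and 1.
yt = np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float64)
yp = np.array([[0.9, 0.05, 0.05], [0.5, 0.4, 0.1]])
plain = weighted_categorical_crossentropy(yt, yp, [1.0, 1.0, 1.0])
down_weighted = weighted_categorical_crossentropy(yt, yp, [0.1, 1.0, 1.0])
```

With all weights equal to 1 this reduces to ordinary categorical cross-entropy. Porting it to Keras would mean computing the same per-pixel term with backend tensor ops (in Keras 2, for example, `K.clip`, `K.log` and `K.sum(..., axis=-1)` from `keras.backend`) inside a closure over the weights, and passing that closure as `loss=`; check the ops module for your Keras version. Because the model output is `(65536, 32)`, `axis=-1` is the class axis, so each weight scales one class.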
Please let me know if more information is needed to diagnose the problem.