
I'm doing semantic segmentation in Keras and tried to modify the categorical_crossentropy loss so that it is class-weighted.

Here is my code:

```python
import tensorflow as tf
from keras.backend import floatx
# private helpers of the TensorFlow backend (assumes Keras 2.0.x, where these exist)
from keras.backend.tensorflow_backend import _to_tensor, _EPSILON


def class_weighted_categorical_crossentropy(output, target, from_logits=False):
    """Categorical crossentropy between an output tensor and a target tensor."""
    parameter = TrainingParameters()  # user-defined config class (not shown) holding class_weights
    # create ones array with shape of target tensor
    # multiply class weight array with inverse class_accuracies for each label
    class_weights = tf.convert_to_tensor(parameter.class_weights, dtype=floatx())
    # weight targets with class weights and create pattern with which loss can be multiplied
    class_weights_pattern = tf.multiply(target, class_weights)
    class_weights_pattern = tf.reduce_sum(class_weights_pattern,
                                          reduction_indices=len(class_weights_pattern.get_shape()) - 1)  # , keep_dims=True)
    if not from_logits:
        # scale preds so that the class probas of each sample sum to 1
        output /= tf.reduce_sum(output,
                                reduction_indices=len(output.get_shape()) - 1,
                                keep_dims=True)
        # manual computation of crossentropy
        epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
        output = tf.clip_by_value(output, epsilon, 1. - epsilon)
        loss = - tf.reduce_sum(target * tf.log(output), reduction_indices=len(output.get_shape()) - 1)
        return tf.multiply(loss, class_weights_pattern)
    else:
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
        return tf.multiply(loss, class_weights_pattern)
```
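For reference, the same weighting can be reproduced outside TensorFlow. This is a minimal NumPy sketch, not the Keras implementation: the function names, the 2x2 "image" with 3 classes, and the weight values are made up for illustration, and the epsilon of 1e-7 is assumed to match the Keras default. It builds the per-pixel weight map as the sum of target * class_weights over the class axis (which, for one-hot targets, picks out the weight of each pixel's true class) and multiplies it into the per-pixel crossentropy, once for probabilities and once for logits, mirroring the two branches above.

```python
import numpy as np

EPS = 1e-7  # assumed clipping epsilon (Keras default)


def weight_pattern(target, class_weights):
    # target is one-hot on the last axis, so this selects the weight of
    # each pixel's true class: shape (..., C) -> (...)
    return np.sum(target * class_weights, axis=-1)


def weighted_crossentropy(probs, target, class_weights):
    # normalise so the class probabilities of each pixel sum to 1, then clip
    probs = probs / np.sum(probs, axis=-1, keepdims=True)
    probs = np.clip(probs, EPS, 1.0 - EPS)
    per_pixel = -np.sum(target * np.log(probs), axis=-1)
    return per_pixel * weight_pattern(target, class_weights)


def weighted_crossentropy_from_logits(logits, target, class_weights):
    # numerically stable log-softmax; -sum(target * log_softmax(logits)) is
    # mathematically the per-pixel loss of softmax_cross_entropy_with_logits
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))
    per_pixel = -np.sum(target * log_probs, axis=-1)
    return per_pixel * weight_pattern(target, class_weights)


# one 2x2 "image" with 3 classes (batch axis omitted)
target = np.array([[[1, 0, 0], [0, 1, 0]],
                   [[0, 0, 1], [1, 0, 0]]], dtype=float)
probs = np.array([[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
                  [[0.2, 0.2, 0.6], [0.5, 0.25, 0.25]]])
class_weights = np.array([1.0, 2.0, 0.5])

print(weight_pattern(target, class_weights))
print(weighted_crossentropy(probs, target, class_weights))
# log(probs) used as logits gives back probs after softmax, so both branches agree
print(weighted_crossentropy_from_logits(np.log(probs), target, class_weights))
```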

The only change is at the end: the loss is multiplied by class_weights_pattern. For each pixel, class_weights_pattern holds the weight of that pixel's class, so it should weight the normal categorical_crossentropy loss pixel by pixel. However, if I train my model with the modified loss, the results are much worse than with the plain Keras categorical_crossentropy loss. Even if I set all class weights to 1, so that my class_weighted_categorical_crossentropy loss should be exactly the same as the Keras categorical_crossentropy loss, the results are worse. I have already printed both losses for a few sample images, and they are exactly the same.
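The claim that all-ones weights reproduce the unweighted loss can be checked numerically: with one-hot targets and every class weight equal to 1, the weight map is exactly 1 at every pixel, so the weighted loss equals the plain crossentropy value for value. The NumPy sketch below uses hypothetical function names and random data with an assumed 2x8x8 shape and 4 classes; it only compares loss values and does not model how Keras feeds the loss during training.

```python
import numpy as np


def crossentropy(probs, target, eps=1e-7):
    # plain per-pixel categorical crossentropy over the last (class) axis
    probs = np.clip(probs / np.sum(probs, axis=-1, keepdims=True), eps, 1.0 - eps)
    return -np.sum(target * np.log(probs), axis=-1)


def class_weighted_crossentropy(probs, target, class_weights, eps=1e-7):
    # per-pixel weight = weight of that pixel's true class (target is one-hot)
    pattern = np.sum(target * class_weights, axis=-1)
    return crossentropy(probs, target, eps) * pattern


rng = np.random.default_rng(0)
num_classes = 4
# random one-hot targets and random positive predictions for a 2x8x8 batch
labels = rng.integers(0, num_classes, size=(2, 8, 8))
target = np.eye(num_classes)[labels]
probs = rng.random((2, 8, 8, num_classes)) + 1e-3

ones = np.ones(num_classes)
same = np.allclose(class_weighted_crossentropy(probs, target, ones),
                   crossentropy(probs, target))
print(same)  # True
```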

Can anybody help me? Why does it not work? Thanks in advance!

nubs91
