
I'm trying to make a network with two branches that separately output a depth map and semantic segmentation data.

In order to train the network, I'd like to use categorical cross entropy for the segmentation branch, and mean squared error for the branch that outputs the depth map.

I couldn't find any info in the Keras documentation for the Functional API on assigning a different loss function to each branch.

Is it possible for me to use these loss functions simultaneously during training, or would it be better for me to train the different branches separately?


1 Answer


From the documentation of Model.compile:

loss: String (name of objective function) or objective function. See losses. If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses. The loss value that will be minimized by the model will then be the sum of all individual losses.

If your output is named, you can use a dictionary mapping the names to the corresponding losses:

from keras.layers import Input, Dense
from keras.models import Model

x = Input((10,))
out1 = Dense(10, activation='softmax', name='segmentation')(x)
out2 = Dense(10, name='depth')(x)
model = Model(x, [out1, out2])
model.compile(loss={'segmentation': 'categorical_crossentropy', 'depth': 'mse'},
              optimizer='adam')

Otherwise, use a list of losses (in the same order as the corresponding model outputs).

x = Input((10,))
out1 = Dense(10, activation='softmax')(x)
out2 = Dense(10)(x)
model = Model(x, [out1, out2])
model.compile(loss=['categorical_crossentropy', 'mse'], optimizer='adam')
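Related: if one loss dominates the other, `compile` also accepts a `loss_weights` argument to scale each output's contribution to the total. A minimal sketch (the 1.0/0.5 weights are illustrative, not a recommendation):

```python
from keras.layers import Input, Dense
from keras.models import Model

x = Input((10,))
out1 = Dense(10, activation='softmax', name='segmentation')(x)
out2 = Dense(10, name='depth')(x)
model = Model(x, [out1, out2])

# The minimized loss is 1.0 * crossentropy + 0.5 * mse.
model.compile(loss={'segmentation': 'categorical_crossentropy', 'depth': 'mse'},
              loss_weights={'segmentation': 1.0, 'depth': 0.5},
              optimizer='adam')
```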
  • What about using a different loss function for the validation set? I am using a weighted loss function for the training set, which has a different number of examples per class (it is unbalanced). But for validation I don't want to use the weighted loss, because the validation set has an equal number of examples per class. So can I pass a different loss function for the validation set? – W. Sam Aug 30 '18 at 22:41
  • @W.Sam How did you implement it? I think the built-in `class_weight` is only applied to the training set. – Yu-Yang Aug 31 '18 at 15:26
  • Does that mean the validation loss reported at each epoch is wrong, because it is based on the training loss function? I implemented it by building a custom loss function and passing it to `loss`. One possible solution is getting access to the "is training" flag, the same flag used by dropout and batch normalization, since both operations behave differently in training and validation. But I am not sure where I can find it. If I can find this flag, I can control which loss to use in each phase. – W. Sam Sep 01 '18 at 16:35
  • 1
    If you're not using the `class_weight` argument or sample weights, then yes, your validation loss is computed by the same function as training loss. Maybe you can use `K.in_train_phase()` function in your custom loss. – Yu-Yang Sep 01 '18 at 17:39
  • Thank you very much Yu-Yang. That's what I am looking for. – W. Sam Sep 01 '18 at 18:04
  • If you can post your answer on my question here, I will accept it. And if you have time for a simple code explanation, I would appreciate that. https://stackoverflow.com/questions/52107555/different-loss-function-for-validation-set-in-keras – W. Sam Sep 01 '18 at 18:22
  • Sure. I'll post an answer with a basic example. – Yu-Yang Sep 01 '18 at 18:30
  • Basically what I did is: `loss = K.in_train_phase(train_loss, val_loss)`. I am not sure if this is right, but I will try to implement it and see. – W. Sam Sep 01 '18 at 18:40
  • Yes that's basically how it should be used. I've posted an answer on the question with an example and some details. Please see if that works for your model. – Yu-Yang Sep 01 '18 at 18:46
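To make the `K.in_train_phase()` idea from the comments concrete, here is a rough sketch of a custom loss that applies class weights during training but plain crossentropy at validation time. The three-class setup and the weight values are made up for illustration, and this is written against `tf.keras`, where `in_train_phase` is available in the backend:

```python
import tensorflow as tf
from tensorflow.keras import backend as K

# Hypothetical per-class weights for an imbalanced 3-class training set.
CLASS_WEIGHTS = tf.constant([0.5, 2.0, 1.0], dtype=tf.float32)

def train_val_loss(y_true, y_pred):
    # Per-sample categorical crossentropy.
    ce = K.categorical_crossentropy(y_true, y_pred)
    # Pick each sample's class weight from its one-hot label.
    sample_weights = tf.reduce_sum(y_true * CLASS_WEIGHTS, axis=-1)
    # The first argument is used in the training phase,
    # the second in the validation/test phase.
    return K.in_train_phase(ce * sample_weights, ce)
```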