I have a semantic segmentation task with only 2 classes: 0 and 1.
The target is of shape NxHxW. Each sample in the batch is a binary image of size HxW with pixel values of 0 or 1.
The output of the net is Nx2xHxW, since there are 2 classes and the net outputs a probability per class, so each channel holds the probability of one class.
How can I use PyTorch's binary_cross_entropy() and binary_cross_entropy_with_logits() in this case?
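A minimal sketch, assuming the two channels come from a softmax over dim 1 (so channel 0 is exactly 1 minus channel 1): only the class-1 channel is needed, and slicing it gives a tensor with the same NxHxW shape as the target. The random `logits` tensor stands in for the real network output and is purely illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, W = 4, 8, 8

# Stand-in for the network: raw scores turned into per-class probabilities
# with a softmax over the channel dimension (shape N x 2 x H x W).
logits = torch.randn(N, 2, H, W)
probs = torch.softmax(logits, dim=1)

# Binary target of shape N x H x W with values 0 or 1 (int64).
target = torch.randint(0, 2, (N, H, W))

# Channel 1 is P(class 1); channel 0 is 1 - channel 1, so it adds nothing.
p_fg = probs[:, 1]  # shape N x H x W, same as target

# binary_cross_entropy expects a float target with the same shape as the input.
loss = F.binary_cross_entropy(p_fg, target.float())
```

With probabilities as input this uses binary_cross_entropy; binary_cross_entropy_with_logits would instead need the raw class-1 score rather than a softmax output.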
In PyTorch, for the multi-class case, I can use the cross-entropy loss. As described in Pytorch semantic segmentation loss function, it can be used with a target containing only the class index of each pixel.
Is there something equivalent for the binary case?
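One possible answer, sketched under the assumption that the network can emit raw logits: a 2-class problem is still a multi-class problem, so cross_entropy accepts the Nx2xHxW output and the NxHxW index target directly. If the network already applies a softmax, taking the log and using nll_loss gives the same value (less numerically stable than log_softmax on logits).

```python
import torch
import torch.nn.functional as F

N, H, W = 4, 8, 8
logits = torch.randn(N, 2, H, W)          # raw scores, no softmax applied
target = torch.randint(0, 2, (N, H, W))   # int64 class indices 0 or 1

# cross_entropy: input N x C x H x W, target N x H x W of class indices.
loss = F.cross_entropy(logits, target)

# If the network outputs probabilities instead, log them and use nll_loss.
probs = torch.softmax(logits, dim=1)
loss_from_probs = F.nll_loss(torch.log(probs), target)
```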
Both binary_cross_entropy() and binary_cross_entropy_with_logits() require the input and the target to have exactly the same shape.
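Two ways to satisfy that shape requirement, as a sketch with random stand-in tensors. Option A expands the target to one-hot Nx2xHxW; note that binary_cross_entropy_with_logits then applies an independent sigmoid per channel rather than a softmax across them. Option B assumes the network is changed to emit a single class-1 logit channel (Nx1xHxW), so the target only needs a channel dimension added.

```python
import torch
import torch.nn.functional as F

N, H, W = 4, 8, 8
target = torch.randint(0, 2, (N, H, W))  # N x H x W, values 0/1 (int64)

# Option A: keep the 2-channel output, one-hot the target to N x 2 x H x W.
# one_hot returns N x H x W x 2, so move the class axis to dim 1.
logits2 = torch.randn(N, 2, H, W)
target_onehot = F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
loss_a = F.binary_cross_entropy_with_logits(logits2, target_onehot)

# Option B: single output channel holding the class-1 logit (N x 1 x H x W);
# add a channel dimension to the target instead.
logits1 = torch.randn(N, 1, H, W)
loss_b = F.binary_cross_entropy_with_logits(logits1, target.unsqueeze(1).float())
```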
What would be the recommended way to handle this, performance-wise?
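On the performance side, one relevant fact is that a 2-way softmax carries only one degree of freedom: softmax([z0, z1])[1] equals sigmoid(z1 - z0). The sketch below checks that the 2-channel cross-entropy equals binary_cross_entropy_with_logits on the logit difference, which suggests a single-channel output loses nothing while halving the size of the final layer's output. Any actual speedup is an assumption that would need to be measured on the real model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, H, W = 4, 8, 8
logits2 = torch.randn(N, 2, H, W)        # stand-in for a 2-channel network output
target = torch.randint(0, 2, (N, H, W))  # class indices 0/1

# Two-channel softmax cross-entropy with class-index targets.
ce = F.cross_entropy(logits2, target)

# softmax([z0, z1])[1] == sigmoid(z1 - z0), so one logit per pixel is enough.
diff = logits2[:, 1] - logits2[:, 0]     # shape N x H x W
bce = F.binary_cross_entropy_with_logits(diff, target.float())
```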