
Using the smooth dice function as proposed in the V-Net paper:

$$\text{DiceLoss} = 1 - \frac{2 \sum_i p_i g_i + s}{\sum_i p_i^2 + \sum_i g_i^2 + s}$$

with $p_i$ the predicted probabilities per voxel, $g_i$ the binary ground truth and $s$ the smooth term.

encoded for multiple classes in PyTorch, with a smooth term added so that the denominator is never zero:

import torch

class DiceLoss(torch.nn.Module):
    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        assert inputs.shape == targets.shape, f"Shapes don't match {inputs.shape} != {targets.shape}"
        inputs = inputs[:, 1:]                                    # skip background class
        targets = targets[:, 1:]                                  # skip background class
        axes = tuple(range(2, len(inputs.shape)))                 # sum over the spatial dimensions, per sample and per class
        intersection = torch.sum(inputs * targets, axes)
        addition = torch.sum(torch.square(inputs) + torch.square(targets), axes)
        return 1 - torch.mean((2 * intersection + self.smooth) / (addition + self.smooth))
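
For reference, a minimal usage sketch (the batch size, class count and the softmax/one-hot preprocessing are assumptions on my part, not taken from the setup above):

probs = torch.softmax(torch.randn(2, 3, 52, 52, 52), dim=1)   # (N, C, D, H, W) predicted probabilities
labels = torch.randint(0, 3, (2, 52, 52, 52))                 # integer class label per voxel
onehot = torch.nn.functional.one_hot(labels, num_classes=3)   # (N, D, H, W, C)
onehot = onehot.permute(0, 4, 1, 2, 3).float()                # -> (N, C, D, H, W)
loss = DiceLoss(smooth=1)(probs, onehot)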

The targets are volumetric sample patches of size 52³ = 140,608 voxels. The patches are sampled without overlap from a volume that is 97% background, 2% liver and 1% tumor, so most patches are completely background. The dice loss is bounded between 0 and 1.

Other questions set the smooth term to 1.0, or much lower, e.g. 1e-7.

Suppose a patch is fully background (as most are) and the prediction is 0.01 (= 1%) liver and 0.01 tumor for every element. With the smooth term set to 1.0, the loss is still

1 - (2 × 0 + 1) / (0 + 140,608 × 0.01² + 1) = 1 - 1 / 15.0608 ≈ 0.9336

Close to the maximum loss even though the prediction is accurate.
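
This can be checked numerically with the implementation above (the tensor setup is mine, the scenario is exactly the one described):

probs = torch.zeros(1, 3, 52, 52, 52)
probs[:, 0] = 0.98                             # background
probs[:, 1] = 0.01                             # liver
probs[:, 2] = 0.01                             # tumor
targets = torch.zeros_like(probs)
targets[:, 0] = 1.0                            # patch is fully background
print(DiceLoss(smooth=1)(probs, targets))      # tensor(0.9336)
print(DiceLoss(smooth=1e-7)(probs, targets))   # ~1.0, the lower smooth is even closer to the maximum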

This forces the network to push the predicted likelihood of liver or tumor down to 0.001 or lower. What is commonly used as the smooth value for the dice loss?

