According to PyTorch's documentation on binary_cross_entropy_with_logits, the weight and pos_weight parameters are described as:

weight

weight (Tensor, optional) – a manual rescaling weight if provided it’s repeated to match input tensor shape

pos_weight

pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.

What are their differences? The explanation is quite vague. If I understand correctly, weight is an individual weight for each pixel (class), whereas pos_weight is the weight for everything that's not background (negative pixel/zero)?

What if I set both parameters? For example:

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

preds = torch.randn(4, 100, 50, 50)
target = torch.zeros((4, 100, 50, 50))
target[:, :, 10:20, 10:20] = 1

pos_weight = target * 100
pos_weight[pos_weight < 100] = 1
weight = target * 100
weight[weight < 100] = 1

loss1 = binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight, weight=weight)
loss2 = binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight)
loss3 = binary_cross_entropy_with_logits(preds, target, weight=weight)

Of loss1, loss2, and loss3, which one is the correct usage?

On the same subject, I was reading a paper that said:

To deal with the unbalanced negative and positive data, we dilate each keypoint by 10 pixels and use weighted cross-entropy loss. The weight for each keypoint is set to 100 while for non-keypoint pixels it is set to 1.

Which one is the correct usage according to the paper?

Thanks in advance for any explanation!

papillon

1 Answer

The pos_weight parameter lets you rescale the loss on the positive examples, which controls the trade-off between recall and precision. A detailed explanation can be found on this thread, along with the explicit math expression. The weight parameter, on the other hand, lets you weigh the different elements of a given batch.
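
To make the distinction concrete, here is a plain-Python sketch of the per-element expression given in the PyTorch documentation, -weight * (pos_weight * y * log σ(x) + (1 - y) * log(1 - σ(x))). The helper name bce_with_logits_elem is hypothetical and is not part of PyTorch; it is only meant to show which term each parameter scales, and it is not numerically stable for large |x|.

```python
import math

def bce_with_logits_elem(x, y, weight=1.0, pos_weight=1.0):
    """Hypothetical helper: loss for one logit x and one target y in [0, 1]."""
    log_sig = -math.log1p(math.exp(-x))           # log(sigmoid(x))
    log_one_minus_sig = -math.log1p(math.exp(x))  # log(1 - sigmoid(x))
    # pos_weight multiplies only the positive (y) term; weight multiplies everything.
    return -weight * (pos_weight * y * log_sig + (1.0 - y) * log_one_minus_sig)

base_pos = bce_with_logits_elem(1.5, 1.0)  # positive target, no weighting
base_neg = bce_with_logits_elem(1.5, 0.0)  # negative target, no weighting

# pos_weight changes the loss of positive targets only.
pos_scaled = bce_with_logits_elem(1.5, 1.0, pos_weight=100.0)
neg_unchanged = bce_with_logits_elem(1.5, 0.0, pos_weight=100.0)

# weight changes the loss of whichever element it is applied to.
neg_scaled = bce_with_logits_elem(1.5, 0.0, weight=100.0)

print(round(base_pos, 4))  # 0.2014
```

In this sketch pos_scaled is 100 times base_pos while neg_unchanged equals base_neg, whereas neg_scaled is 100 times base_neg.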

Here is a minimal example:

>>> import torch
>>> target = torch.ones([10, 64], dtype=torch.float32)
>>> output = torch.full([10, 64], 1.5)

>>> criterion = torch.nn.BCEWithLogitsLoss() # w/o weight
>>> criterion(output, target)
tensor(0.2014) # all batch elements weighted equally

>>> weight = torch.rand(10,1)
>>> criterion = torch.nn.BCEWithLogitsLoss(weight=weight) # w/ weight
>>> criterion(output, target)
tensor(0.0908) # per element weighting

This is identical to:

>>> criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
>>> torch.mean(criterion(output, target)*weight)
tensor(0.0908)
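
As for the weighting scheme quoted from the paper (100 on keypoint pixels, 1 elsewhere), the sketch below compares the three calls from the question on a small tensor. It assumes hard 0/1 targets, and the shapes, seed, and variable names are made up for illustration. Under that assumption, a weight map of 100 on positives and a scalar pos_weight of 100 should give the same loss, because the positive term is the only non-zero term on a positive pixel. Passing both multiplies the two factors, so positives end up weighted by 100 * 100.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
preds = torch.randn(2, 3, 8, 8)    # logits: (batch, channels, H, W), illustrative sizes
target = torch.zeros(2, 3, 8, 8)
target[:, :, 2:4, 2:4] = 1.0       # a small block of "keypoint" pixels

weight = 1.0 + 99.0 * target       # 100 on keypoint pixels, 1 on background
pos_weight = torch.tensor(100.0)   # scalar, broadcast over every element

loss_weight = F.binary_cross_entropy_with_logits(preds, target, weight=weight)
loss_pos = F.binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight)
loss_both = F.binary_cross_entropy_with_logits(
    preds, target, weight=weight, pos_weight=pos_weight
)

# Manual reference for the combined case: positives scaled by 100 * 100.
per_elem = F.binary_cross_entropy_with_logits(preds, target, reduction="none")
manual_both = (per_elem * (1.0 + 9999.0 * target)).mean()

print(loss_weight.item(), loss_pos.item(), loss_both.item(), manual_both.item())
```

If the paper's description is read literally, either the weight-only or the pos_weight-only call matches it, while the combined call (loss1 in the question) weights keypoint pixels far more heavily than the stated 100.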
Ivan
  • If `pos_weight` deals with the positive samples and `weight` with individual elements, let's say I have 100 channels in each output (each channel represents a keypoint), and in each channel there are only two classes (background and a single object keypoint). The channels are highly imbalanced: around 50-80 percent of all 100 channels consist only of background (the object is not fully in frame). The keypoints are also highly imbalanced (only 10 pixels are considered keypoint in the whole frame) ... (next comment) – papillon Apr 01 '22 at 03:22
  • How do you propose I tackle this issue? Would using both `pos_weight` and `weight` solve it (as in `loss1`)? – papillon Apr 01 '22 at 03:22
  • That means for any channel that contains only background (keypoint not in frame), the `pos_weight` is 1, and it is 100 if there's any keypoint in it. For the keypoint itself, the `weight` is 100 for keypoint pixels and `1` for background pixels. Does that make sense? – papillon Apr 01 '22 at 03:29