Let's use the following example for a semantic segmentation problem using TorchMetrics, where we predict tensors of shape (batch_size, classes, height, width):
import torch

# shape: (1, 3, 2, 2) => (batch_size, classes, height, width)
mask_multiclass_pred = torch.tensor(
    [[
        [
            # predictions for first class per pixel
            [0.85, 0.4],
            [0.4, 0.3],
        ],
        [
            # predictions for second class per pixel
            [0, 0.8],
            [0, 1],
        ],
        [
            # predictions for third class per pixel
            [0.8, 0.6],
            [0.7, 0.3],
        ]
    ]],
    dtype=torch.float32
)
Obviously, if we reduce this to the actual predicted classes as an index tensor:
reduced_pred = torch.argmax(mask_multiclass_pred, dim=1)
reduced_pred = torch.where(torch.amax(mask_multiclass_pred, dim=1) >= 0.5, reduced_pred, -1)
We get:
# shape: (1, 2, 2) => (batch_size, height, width)
tensor([[[0, 1],
         [2, 1]]])
...for the predictions.
Let's suppose the following is our ground truth for the labels, in shape (batch_size, height, width). The MulticlassAccuracy documentation suggests the targets should be (N, ...), i.e. only batch_size plus ... for the extra dimensions, which in semantic segmentation are height and width:
# shape: (1, 2, 2) => (batch_size, height, width)
# as suggested by TorchMetrics targets should be (N, ...) where ... is the extra dimensions, in this case 2D => class per pixel
mask_multiclass_gt = torch.tensor(
    [
        [
            # class 0, 1, or 2 per pixel => (2, 2) shape for mask
            [0, 1],
            [0, 2],
        ],
    ],
    dtype=torch.int
)
Now, if we calculate the MulticlassAccuracy:
from torchmetrics.classification import MulticlassAccuracy

seg_acc_cls = MulticlassAccuracy(num_classes=3, top_k=1, average="none", multidim_average="global")
seg_acc_cls(mask_multiclass_pred, mask_multiclass_gt)
We get the following result:
# shape (3,) => one accuracy per class (3 classes)
tensor([0.5000, 1.0000, 0.0000])
Why is this the output?
For example, shouldn't the first class be 0.75 instead of 0.5? With the 0.5 threshold used above, our reduced predictions, viewed one-vs-rest for the first class, would be:
[0, 1] => [True, False]
[2, 1] => [False, False]
Compared with the ground truth, that gives 1 TP, 2 TN, and 1 FN (lower left). So shouldn't the accuracy be (1 TP + 2 TN) / 4 = 0.75?
Likewise, the second class would be:
[0, 1] => [False, True]
[2, 1] => [False, True]
So again, we have 1 TP, but also 1 FP (lower right) and 2 TN, which again should be (1 TP + 2 TN) / 4 = 0.75 and not 1.0.
For the 3rd class we would get these reduced predictions:
[0, 1] => [False, False]
[2, 1] => [True, False]
That gives 0 TP (in the ground truth only the lower-right pixel is class 2), 1 FP (lower left), 1 FN (lower right), and 2 TN, so the accuracy should be (0 TP + 2 TN) / 4 = 0.5 and not 0.0.
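To make the per-class counts above reproducible, here is a minimal sketch in plain PyTorch. It does not call TorchMetrics at all. The helper name `one_vs_rest_counts` is my own (hypothetical, not a library function), and the two tensors are simply the reduced predictions and ground truth from the example, typed in by hand.

```python
import torch

# Reduced predictions and ground truth from the example above,
# both of shape (batch_size, height, width).
reduced_pred = torch.tensor([[[0, 1], [2, 1]]])
mask_multiclass_gt = torch.tensor([[[0, 1], [0, 2]]])


def one_vs_rest_counts(pred, target, cls):
    """Hypothetical helper: (TP, FP, TN, FN) with `cls` treated as the positive class."""
    p = pred == cls    # pixels predicted as `cls`
    t = target == cls  # pixels labelled as `cls`
    tp = int((p & t).sum())
    fp = int((p & ~t).sum())
    tn = int((~p & ~t).sum())
    fn = int((~p & t).sum())
    return tp, fp, tn, fn


counts = [one_vs_rest_counts(reduced_pred, mask_multiclass_gt, c) for c in range(3)]

# The accuracy I expected: (TP + TN) / all pixels, per class.
expected_acc = [(tp + tn) / (tp + fp + tn + fn) for tp, fp, tn, fn in counts]

# For comparison: TP / (TP + FN), per class (guarding against an empty class).
tp_over_tp_fn = [tp / (tp + fn) if (tp + fn) else 0.0 for tp, fp, tn, fn in counts]

for c in range(3):
    print(f"class {c}: (TP, FP, TN, FN)={counts[c]}, "
          f"(TP+TN)/N={expected_acc[c]:.2f}, TP/(TP+FN)={tp_over_tp_fn[c]:.2f}")
```

With these inputs, (TP + TN) / N comes out as 0.75, 0.75, 0.5, which is what I expected, while TP / (TP + FN) comes out as 0.5, 1.0, 0.0, which coincides numerically with the tensor MulticlassAccuracy returned. That is only an observation from this one example, not a statement about how the library computes the metric.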