Suppose you have the VOC 2012 file paths in a dataset of pairs `(image_path: tf.string, image_label_path: tf.string)`.
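If you don't have that dataset yet, here is a minimal sketch of collecting the pairs. It assumes the standard layout of the extracted `VOCdevkit/VOC2012` folder (`JPEGImages/`, `SegmentationClass/`, and the split lists under `ImageSets/Segmentation/`); the helper name `voc_segmentation_pairs` is hypothetical.

```python
import os

# Assumed layout of the extracted VOC 2012 archive:
#   <root>/JPEGImages/<id>.jpg
#   <root>/SegmentationClass/<id>.png
#   <root>/ImageSets/Segmentation/<split>.txt   (one image id per line)
def voc_segmentation_pairs(root, split="train"):
    """Hypothetical helper: return parallel lists of image and label-mask paths."""
    split_file = os.path.join(root, "ImageSets", "Segmentation", split + ".txt")
    with open(split_file) as f:
        ids = [line.strip() for line in f if line.strip()]
    image_paths = [os.path.join(root, "JPEGImages", i + ".jpg") for i in ids]
    label_paths = [os.path.join(root, "SegmentationClass", i + ".png") for i in ids]
    return image_paths, label_paths
```

The two lists can then be turned into the dataset with `tf.data.Dataset.from_tensor_slices((image_paths, label_paths))`.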
First, define the RGB color of each of the 21 classes (background plus the 20 object classes):

```python
import tensorflow as tf

# RGB color of each PASCAL VOC class; the row index is the class label (0 = background).
_COLORS = tf.constant([
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
    [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
    [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128],
    [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
    [0, 64, 128]
], dtype=tf.int32)
```
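These colors are the first 21 entries of the well-known VOC palette, which spreads the bits of the class index, three at a time, over the high bits of R, G and B. The pure-Python reimplementation below (the name `voc_colormap` is hypothetical) reproduces the table above, and also shows that index 255, which VOC uses for "void"/boundary pixels, comes out as `(224, 224, 192)`.

```python
def voc_colormap(n=256):
    """Hypothetical helper: PASCAL VOC palette as a list of [R, G, B] rows."""
    def bit(value, idx):
        return (value >> idx) & 1

    cmap = []
    for label in range(n):
        r = g = b = 0
        c = label
        for j in range(8):
            # Bits 0, 1, 2 of the remaining index go to R, G, B, filling from the MSB down.
            r |= bit(c, 0) << (7 - j)
            g |= bit(c, 1) << (7 - j)
            b |= bit(c, 2) << (7 - j)
            c >>= 3
        cmap.append([r, g, b])
    return cmap
```

With it, the constant could equivalently be written as `tf.constant(voc_colormap(21), dtype=tf.int32)`.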
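And the loading function:

```python
def load_images(image_path, label_path):
    # Decode the RGB photo.
    image = tf.io.read_file(image_path)
    image = tf.io.decode_jpeg(image, channels=3)
    # Decode the color-coded mask as RGB and convert it to class indices.
    image_label = tf.io.read_file(label_path)
    image_label = tf.io.decode_png(image_label, channels=3)
    image_label = rgb_to_label(image_label)
    return image, image_label


def rgb_to_label(segm):
    # (H, W, 3) -> (H, W, 3, 1), cast to the dtype of _COLORS so == compares ints.
    segm = tf.cast(segm[..., tf.newaxis], _COLORS.dtype)
    # (H, W, 3, 1) == (3, 21) -> (H, W, 3, 21); all 3 channels equal -> (H, W, 21);
    # argmax picks the matching class index -> (H, W).
    return tf.argmax(tf.reduce_all(segm == tf.transpose(_COLORS), axis=-2), axis=-1)
```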
And just apply the function:

```python
ds = ds.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
```
Explanation:

`segm[..., tf.newaxis].shape == (H, W, 3, 1)`, while `tf.transpose(_COLORS).shape == (3, 21)`. Comparing the two with `==` broadcasts them into a tensor of shape `(H, W, 3, 21)`.
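The broadcast can be checked with NumPy, whose broadcasting rules are the same as TensorFlow's here: shapes are aligned from the right, so `(H, W, 3, 1)` against `(3, 21)` behaves like `(H, W, 3, 1)` against `(1, 1, 3, 21)`. The arrays below are zero-filled stand-ins, since only the shapes matter.

```python
import numpy as np

H, W = 4, 5
colors = np.zeros((21, 3), dtype=np.int32)  # stand-in for _COLORS
segm = np.zeros((H, W, 3), dtype=np.uint8)  # stand-in for a decoded RGB mask

expanded = segm[..., np.newaxis].astype(np.int32)  # (H, W, 3, 1)
palette = colors.T                                 # (3, 21)
matches = expanded == palette                      # broadcast -> (H, W, 3, 21)
```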
If pixel `(h, w)` of the segmentation mask matches the color of class `c`, then all three intensities along the third axis (the channel axis, of size 3) match. Hence `tf.reduce_all(..., axis=-2)` reduces the comparison to an `(H, W, 21)` one-hot tensor that is `True` at the corresponding label index and `False` everywhere else.

Lastly, `tf.argmax(..., axis=-1)` finds that index for each pixel, resulting in an `(H, W)` label image.
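To see the values rather than the shapes, here is a NumPy mirror of `rgb_to_label` applied to a tiny 1×3 mask. It is a sketch for illustration only: the name `rgb_to_label_np` is hypothetical, and the palette is truncated to the first four VOC colors to keep it short.

```python
import numpy as np

# First four VOC colors only (the real table has 21 rows).
colors = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0]], dtype=np.int32)

def rgb_to_label_np(segm, colors):
    """Hypothetical NumPy mirror of rgb_to_label: (H, W, 3) RGB -> (H, W) indices."""
    matches = segm[..., np.newaxis].astype(colors.dtype) == colors.T  # (H, W, 3, C)
    one_hot = matches.all(axis=-2)                                    # (H, W, C)
    return one_hot.argmax(axis=-1)                                    # (H, W)

# One row of three pixels: background, class 1, class 3.
mask = np.array([[[0, 0, 0], [128, 0, 0], [128, 128, 0]]], dtype=np.uint8)
labels = rgb_to_label_np(mask, colors)
```

Note the third pixel, `(128, 128, 0)`: it shares R with class 1 and G with class 2, but only class 3 matches on all three channels, which is exactly why the `reduce_all` over the channel axis is needed.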
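It's worth mentioning that `argmax` returns 0 when all entries are `False` (in practice it picks the first maximum, although the `tf.math.argmax` documentation does not guarantee which index wins a tie). So a pixel with an unknown color, one not present in `_COLORS`, will be assigned label 0 (the background). This includes the `(224, 224, 192)` boundary ("void") pixels that VOC masks draw around objects.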
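If you would rather keep unknown pixels out of the background class, one option (my own variant, not part of the answer above) is to check whether any color matched and substitute an ignore index otherwise. The sketch below does this in NumPy; `IGNORE_LABEL = 255` is an assumption chosen to follow the usual VOC "void" convention, and the palette is truncated to three colors.

```python
import numpy as np

colors = np.array([[0, 0, 0], [128, 0, 0], [0, 128, 0]], dtype=np.int32)  # truncated palette
IGNORE_LABEL = 255  # assumption: the usual VOC "void" index

def rgb_to_label_with_ignore(segm, colors, ignore=IGNORE_LABEL):
    """Hypothetical variant: pixels matching no color get `ignore` instead of 0."""
    one_hot = (segm[..., np.newaxis].astype(colors.dtype) == colors.T).all(axis=-2)
    labels = one_hot.argmax(axis=-1)  # all-False rows give 0 here
    # Keep argmax only where some color actually matched.
    return np.where(one_hot.any(axis=-1), labels, ignore)

# A class-1 pixel next to a VOC boundary pixel.
mask = np.array([[[128, 0, 0], [224, 224, 192]]], dtype=np.uint8)
plain = (mask[..., np.newaxis].astype(np.int32) == colors.T).all(axis=-2).argmax(axis=-1)
with_ignore = rgb_to_label_with_ignore(mask, colors)
```

In TensorFlow the same idea would presumably read `tf.where(tf.reduce_any(one_hot, axis=-1), labels, tf.constant(255, tf.int64))`, since `tf.argmax` returns `int64` by default; treat that line as an untested sketch.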