The Python package platipy can encode a list of label maps into a single (multi-valued) label map.
To install:
pip install -U pip
pip install platipy
Here is a short example:
import SimpleITK as sitk
from platipy.imaging.label.utils import binary_encode_structure_list
img_label_1 = sitk.ReadImage("img_label_1.nii.gz")
img_label_2 = sitk.ReadImage("img_label_2.nii.gz")
img_label_3 = sitk.ReadImage("img_label_3.nii.gz")
# etc., for however many labels you have
label_list = [img_label_1, img_label_2, img_label_3]
img_encoded = binary_encode_structure_list(label_list)
If you need a numpy array, you can convert the SimpleITK image directly (for a 3D image this gives a 3D array, with axes ordered z, y, x):
arr_encoded = sitk.GetArrayFromImage(img_encoded)
N.B. You can also decode an encoded label map (e.g. the output of your NN) using tools in platipy:
from platipy.imaging.label.utils import binary_decode_image
# img_prediction_encoded: a SimpleITK image using the same encoding as img_encoded above
label_list = binary_decode_image(img_prediction_encoded)
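For intuition about what this kind of encoding does, here is a minimal, dependency-free sketch of bitmask encoding on flat lists of voxels. It assumes each label map is assigned its own bit, so overlapping labels stay distinguishable. It illustrates the general idea only and is not platipy's actual implementation; the names `encode_masks` and `decode_masks` are hypothetical.

```python
def encode_masks(masks):
    """Pack equally sized binary masks into one list of integer codes.

    Mask i contributes bit i (value 2**i) wherever it is non-zero.
    """
    if not masks:
        return []
    n = len(masks[0])
    if any(len(m) != n for m in masks):
        raise ValueError("all masks must have the same length")
    encoded = [0] * n
    for bit, mask in enumerate(masks):
        for voxel, value in enumerate(mask):
            if value:
                encoded[voxel] |= 1 << bit
    return encoded


def decode_masks(encoded, num_masks):
    """Recover the binary masks from the integer codes."""
    return [[(code >> bit) & 1 for code in encoded] for bit in range(num_masks)]


# Three toy masks over five voxels; masks a and b overlap at voxel 1.
mask_a = [1, 1, 0, 0, 0]
mask_b = [0, 1, 1, 0, 0]
mask_c = [0, 0, 0, 1, 0]

encoded = encode_masks([mask_a, mask_b, mask_c])
decoded = decode_masks(encoded, 3)
print(encoded)  # [1, 3, 2, 4, 0]
```

The overlapping voxel gets the code 3 (bits 0 and 1 both set), which is why a single encoded map can carry several labels per voxel and still be decoded back to the original masks.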
Hope this helps!