I am trying to use the Keras implementation of DeepLabV3+ for a binary segmentation task with a custom number of input channels (e.g., 4 instead of 3).
What I have modified so far:
- I know that I cannot use the ImageNet weights (trained on 3-channel images), so I have set weights=None (random initialization).
- I have changed the input shape for a 4-channel implementation: keras.Input(shape=(image_size, image_size, 4))
- Since this is a binary segmentation task, I changed the output layer to use a sigmoid activation: model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same", activation='sigmoid')(x)
Here is the original code if the link is unavailable:
def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    padding="same",
    use_bias=False,
):
    """Apply Conv2D -> BatchNormalization -> ReLU to ``block_input``.

    Args:
        block_input: Tensor fed to the convolution.
        num_filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation rate of the convolution.
        padding: Padding mode forwarded to ``Conv2D``.
        use_bias: Whether the convolution adds a bias term.

    Returns:
        The ReLU-activated, batch-normalized convolution output.
    """
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        # Forward the caller's choice. This used to be hard-coded to "same",
        # which silently ignored the ``padding`` argument.
        padding=padding,
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return tf.nn.relu(x)
def DilatedSpatialPyramidPooling(dspp_input):
    """Build the multi-branch pyramid-pooling head on top of ``dspp_input``.

    One branch average-pools with a window equal to ``shape[-3] x shape[-2]``
    of the input, applies a 1x1 convolution block and bilinearly upsamples the
    result back. Four more branches apply convolution blocks to the input with
    (kernel_size, dilation_rate) of (1, 1), (3, 6), (3, 12) and (3, 18). The
    five outputs are concatenated along the last axis and passed through a
    final 1x1 convolution block, whose output is returned.
    """
    dims = dspp_input.shape
    height, width = dims[-3], dims[-2]

    # Pooling branch: pool, 1x1 conv block, then upsample back.
    pooled = layers.AveragePooling2D(pool_size=(height, width))(dspp_input)
    pooled = convolution_block(pooled, kernel_size=1, use_bias=True)
    upsample = layers.UpSampling2D(
        size=(height // pooled.shape[1], width // pooled.shape[2]),
        interpolation="bilinear",
    )
    branches = [upsample(pooled)]

    # Convolution branches, created in the same order as they are concatenated.
    for kernel, rate in ((1, 1), (3, 6), (3, 12), (3, 18)):
        branches.append(
            convolution_block(dspp_input, kernel_size=kernel, dilation_rate=rate)
        )

    merged = layers.Concatenate(axis=-1)(branches)
    return convolution_block(merged, kernel_size=1)
def DeeplabV3Plus(image_size, num_classes):
    """Assemble a DeepLabV3+ model on a ResNet50 loaded with "imagenet" weights.

    Args:
        image_size: Side length of the square, 3-channel model input.
        num_classes: Number of filters in the final 1x1 convolution.

    Returns:
        A ``keras.Model`` mapping the input image to the output of the final
        1x1 convolution, which has no activation applied.
    """
    model_input = keras.Input(shape=(image_size, image_size, 3))
    backbone = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=model_input
    )

    # Pyramid-pooling head on the "conv4_block6_2_relu" backbone features,
    # upsampled to a quarter of the input resolution.
    pyramid = DilatedSpatialPyramidPooling(
        backbone.get_layer("conv4_block6_2_relu").output
    )
    quarter = image_size // 4
    upsampled_pyramid = layers.UpSampling2D(
        size=(quarter // pyramid.shape[1], quarter // pyramid.shape[2]),
        interpolation="bilinear",
    )(pyramid)

    # Skip connection from "conv2_block3_2_relu", reduced to 48 filters.
    skip = convolution_block(
        backbone.get_layer("conv2_block3_2_relu").output,
        num_filters=48,
        kernel_size=1,
    )

    features = layers.Concatenate(axis=-1)([upsampled_pyramid, skip])
    for _ in range(2):
        features = convolution_block(features)

    # Upsample back to the full input resolution.
    features = layers.UpSampling2D(
        size=(image_size // features.shape[1], image_size // features.shape[2]),
        interpolation="bilinear",
    )(features)

    raw_scores = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(
        features
    )
    return keras.Model(inputs=model_input, outputs=raw_scores)
# Build the network; IMAGE_SIZE and NUM_CLASSES are defined outside this snippet.
model = DeeplabV3Plus(
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
)
Here is the code with my changes for 4-channel input and binary output:
def DeeplabV3Plus(image_size, num_classes=1, num_channels=4):
    """Assemble a DeepLabV3+ model with a randomly initialised ResNet50.

    The backbone is created with ``weights=None``, so the input is not tied to
    the 3-channel layout of pretrained weights and the channel count can be
    chosen freely.

    Args:
        image_size: Side length of the square model input.
        num_classes: Number of filters in the final 1x1 convolution.
        num_channels: Number of input channels. Defaults to 4, the value that
            was previously hard-coded.

    Returns:
        A ``keras.Model`` whose output is the sigmoid-activated result of the
        final 1x1 convolution.
    """
    model_input = keras.Input(shape=(image_size, image_size, num_channels))
    resnet50 = keras.applications.ResNet50(
        weights=None, include_top=False, input_tensor=model_input
    )

    # Pyramid-pooling head on the "conv4_block6_2_relu" backbone features,
    # upsampled to a quarter of the input resolution.
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)
    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)

    # Skip connection from "conv2_block3_2_relu", reduced to 48 filters.
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)

    # Upsample back to the full input resolution.
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)

    # NOTE(review): the output is already sigmoid-activated, so the loss chosen
    # at compile time should expect probabilities rather than logits. The
    # compile call is not visible here; confirm.
    model_output = layers.Conv2D(
        num_classes, kernel_size=(1, 1), padding="same", activation="sigmoid"
    )(x)
    return keras.Model(inputs=model_input, outputs=model_output)
# Build the network; IMAGE_SIZE and NUM_CLASSES are defined outside this snippet.
# Passing num_classes explicitly overrides the function's default of 1.
model = DeeplabV3Plus(
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
)
However, I am getting the following errors:
InvalidArgumentError: Graph execution error:
Detected at node 'assert_greater_equal/Assert/AssertGuard/Assert'
and
Node: 'assert_greater_equal/Assert/AssertGuard/Assert'
2 root error(s) found.
(0) INVALID_ARGUMENT: assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (model_5/conv2d_59/Sigmoid:0) = ] [[[[nan][nan][nan]]]...] [y (Cast_8/x:0) = ] [0]
[[{{node assert_greater_equal/Assert/AssertGuard/Assert}}]]
[[assert_greater_equal_1/Assert/AssertGuard/pivot_f/_23/_65]]
(1) INVALID_ARGUMENT: assertion failed: [predictions must be >= 0] [Condition x >= y did not hold element-wise:] [x (model_5/conv2d_59/Sigmoid:0) = ] [[[[nan][nan][nan]]]...] [y (Cast_8/x:0) = ] [0]
[[{{node assert_greater_equal/Assert/AssertGuard/Assert}}]]
Note: My images and masks are both scaled to the range [0, 1].