I am quite new to the ResNet/U-Net architectures, and even newer to transfer learning.
I have a segmentation problem to solve, and I have been advised to use the ResNet architecture together with transfer learning. I found an interesting sequential structure on GitHub, copied it, and combined it with the ResNet50 application.
My questions/remarks are:
- Can I couple this structure with the ResNet50 one?
- When using ResNet50, I specify my input shape as (256, 256, 3). But the summary gives me totally different shapes: the resnet50 layer outputs (None, 8, 8, 2048) and the final layer outputs (None, 8, 8, 2), and I feel like I have no way of controlling that.
- I guess I haven't quite understood how to couple the code I found on GitHub with transfer learning. The aim of this code is to automatically segment grayscale pictures. The inputs are 256 × 256 with 3 channels; the output should be 256 × 256 × the number of classes to segment (in this case, 2).
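A minimal sketch of the input and target shapes described above, assuming each grayscale picture is a 2-D NumPy array that is replicated into 3 channels (to match the ImageNet-pretrained ResNet50 input) and each ground-truth mask stores one integer class label (0 or 1) per pixel. The helpers `gray_to_rgb` and `mask_to_one_hot` are hypothetical names for illustration, not part of Keras.

```python
import numpy as np

def gray_to_rgb(gray):
    """Hypothetical helper: (H, W) grayscale -> (H, W, 3) by copying the single channel."""
    return np.repeat(gray[..., np.newaxis], 3, axis=-1)

def mask_to_one_hot(mask, num_classes=2):
    """Hypothetical helper: (H, W) integer labels -> (H, W, num_classes) one-hot target."""
    return np.eye(num_classes, dtype=np.float32)[mask]

# Stand-in data: a random "picture" and a thresholded mask with labels in {0, 1}.
gray = np.random.rand(256, 256).astype(np.float32)
mask = (gray > 0.5).astype(np.int64)

x = gray_to_rgb(gray)       # model input shape
y = mask_to_one_hot(mask)   # target shape, one probability vector per pixel
print(x.shape, y.shape)     # (256, 256, 3) (256, 256, 2)
```

A target of this shape only lines up with a per-pixel softmax whose output is also (256, 256, 2).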
I am quite lost :)
Thanks everyone
Here is my code:
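from tensorflow import keras
from tensorflow.keras import applications
from tensorflow.keras.layers import Activation, Conv2D, add, concatenate

filters = 64
base_model = applications.ResNet50(weights='imagenet',
                                   include_top=False,
                                   input_shape=(256, 256, 3),
                                   pooling=None)
base_model.trainable = False
inputs = keras.Input(shape=(256, 256, 3))
# ResNet50 downsamples by 32: its output shape here is (None, 8, 8, 2048)
main_path = base_model(inputs, training=False)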
#ENCODER
to_decoder = []
main_path2 = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters=filters, kernel_size=(1, 1), strides=(1, 1))(main_path)
main_path = add([shortcut, main_path2])
to_decoder.append(main_path)
#main_path = res_block(main_path, [filters*2, filters*2], [(2, 2), (1, 1)])
main_path2 = Activation(activation='relu')(main_path)
main_path2 = Conv2D(filters=filters*2, kernel_size=(3, 3), padding='same', strides=(2, 2))(main_path2)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters*2, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters*2, kernel_size=(1, 1), strides=(2,2))(main_path)
main_path = add([shortcut, main_path2])
to_decoder.append(main_path)
#main_path = res_block(main_path, [filters*4, filters*4], [(2, 2), (1, 1)])
main_path2 = Activation(activation='relu')(main_path)
main_path2 = Conv2D(filters=filters*4, kernel_size=(3, 3), padding='same', strides=(2, 2))(main_path2)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters*4, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters*4, kernel_size=(1, 1), strides=(2,2))(main_path)
main_path = add([shortcut, main_path2])
to_decoder.append(main_path)
#RES BLOCK
#path = res_block(to_decoder[2], [filters*8, filters*8], [(2, 2), (1, 1)])
main_path2 = Activation(activation='relu')(main_path)
main_path2 = Conv2D(filters=filters*8, kernel_size=(3, 3), padding='same', strides=(2, 2))(main_path2)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters*8, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters*8, kernel_size=(1, 1), strides=(2, 2))(main_path)
main_path = add([shortcut, main_path2])
#DECODER
main_path = keras.layers.Conv2DTranspose(filters*4, (2, 2), strides=(2, 2), padding="same")(main_path)
main_path = concatenate([main_path, to_decoder[2]], axis=3)
#main_path = res_block(main_path, [filters*4, filters*4], [(1, 1), (1, 1)])
main_path2 = Activation(activation='relu')(main_path)
main_path2 = Conv2D(filters=filters*4, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters*4, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters*4, kernel_size=(1, 1), strides=(1,1))(main_path)
main_path = add([shortcut, main_path2])
main_path = keras.layers.Conv2DTranspose(filters * 2, (2, 2), strides=(2, 2), padding="same")(main_path)
main_path = concatenate([main_path, to_decoder[1]], axis=3)
#main_path = res_block(main_path, [filters*2, filters*2], [(1, 1), (1, 1)])
main_path2 = Activation(activation='relu')(main_path)
main_path2 = Conv2D(filters=filters*2, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters*2, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters*2, kernel_size=(1, 1), strides=(1,1))(main_path)
main_path = add([shortcut, main_path2])
main_path = keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding="same")(main_path)
main_path = concatenate([main_path, to_decoder[0]], axis=3)
#main_path = res_block(main_path, [filters, filters], [(1, 1), (1, 1)])
main_path2 = Activation(activation='relu')(main_path)
main_path2 = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
main_path2 = Activation(activation='relu')(main_path2)
main_path2 = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', strides=(1, 1))(main_path2)
shortcut = Conv2D(filters, kernel_size=(1, 1), strides=(1,1))(main_path)
main_path = add([shortcut, main_path2])
# 1x1 conv + softmax: per-pixel class probabilities (spatial size is still 8x8 here)
outputs = Conv2D(filters=2, kernel_size=(1, 1), activation='softmax')(main_path)
model = keras.Model(inputs, outputs)
model.summary()
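To see where the 8 × 8 output comes from, here is a small pure-Python trace of the spatial size through the code above. It assumes that each of ResNet50's five downsampling steps maps a size n to ceil(n / 2) (overall factor 32), that a 'same'-padded Conv2D with stride s maps n to ceil(n / s), and that each Conv2DTranspose with stride 2 doubles n. The function names `same_conv` and `trace_shapes` are hypothetical.

```python
import math

def same_conv(n, stride):
    """Spatial size after a 'same'-padded convolution with the given stride."""
    return math.ceil(n / stride)

def trace_shapes(input_size=256):
    sizes = {}
    n = input_size
    # Assumption: ResNet50 (include_top=False) halves the size five times (stride 32).
    for _ in range(5):
        n = same_conv(n, 2)
    sizes["resnet50"] = n
    # The four residual blocks in the code: strides 1, 2, 2, 2.
    skips = []
    for i, stride in enumerate([1, 2, 2, 2]):
        n = same_conv(n, stride)
        if i < 3:
            skips.append(n)  # the first three outputs go into to_decoder
    sizes["bottleneck"] = n
    # Three Conv2DTranspose (stride 2) layers, each concatenated with a skip.
    for skip in reversed(skips):
        n *= 2
        assert n == skip, "concatenate needs matching spatial sizes"
    sizes["output"] = n
    return sizes

sizes = trace_shapes(256)
print(sizes)    # {'resnet50': 8, 'bottleneck': 1, 'output': 8}
missing = int(math.log2(256 // sizes["output"]))
print(missing)  # 5: number of further x2 upsampling steps needed to get back to 256
```

The decoder mirrors only the three extra encoder blocks, so it climbs back to ResNet50's 8 × 8 resolution, not to the 256 × 256 input.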
The summary looks like this:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_92 (InputLayer) [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
resnet50 (Functional) (None, 8, 8, 2048) 23587712 input_92[0][0]
__________________________________________________________________________________________________
conv2d_436 (Conv2D) (None, 8, 8, 64) 1179712 resnet50[0][0]
__________________________________________________________________________________________________
activation_241 (Activation) (None, 8, 8, 64) 0 conv2d_436[0][0]
__________________________________________________________________________________________________
conv2d_438 (Conv2D) (None, 8, 8, 64) 131136 resnet50[0][0]
__________________________________________________________________________________________________
conv2d_437 (Conv2D) (None, 8, 8, 64) 36928 activation_241[0][0]
__________________________________________________________________________________________________
add_136 (Add) (None, 8, 8, 64) 0 conv2d_438[0][0]
conv2d_437[0][0]
__________________________________________________________________________________________________
activation_242 (Activation) (None, 8, 8, 64) 0 add_136[0][0]
__________________________________________________________________________________________________
conv2d_439 (Conv2D) (None, 4, 4, 128) 73856 activation_242[0][0]
__________________________________________________________________________________________________
activation_243 (Activation) (None, 4, 4, 128) 0 conv2d_439[0][0]
__________________________________________________________________________________________________
conv2d_441 (Conv2D) (None, 4, 4, 128) 8320 add_136[0][0]
__________________________________________________________________________________________________
conv2d_440 (Conv2D) (None, 4, 4, 128) 147584 activation_243[0][0]
__________________________________________________________________________________________________
add_137 (Add) (None, 4, 4, 128) 0 conv2d_441[0][0]
conv2d_440[0][0]
__________________________________________________________________________________________________
activation_244 (Activation) (None, 4, 4, 128) 0 add_137[0][0]
__________________________________________________________________________________________________
conv2d_442 (Conv2D) (None, 2, 2, 256) 295168 activation_244[0][0]
__________________________________________________________________________________________________
activation_245 (Activation) (None, 2, 2, 256) 0 conv2d_442[0][0]
__________________________________________________________________________________________________
conv2d_444 (Conv2D) (None, 2, 2, 256) 33024 add_137[0][0]
__________________________________________________________________________________________________
conv2d_443 (Conv2D) (None, 2, 2, 256) 590080 activation_245[0][0]
__________________________________________________________________________________________________
add_138 (Add) (None, 2, 2, 256) 0 conv2d_444[0][0]
conv2d_443[0][0]
__________________________________________________________________________________________________
activation_246 (Activation) (None, 2, 2, 256) 0 add_138[0][0]
__________________________________________________________________________________________________
conv2d_445 (Conv2D) (None, 1, 1, 512) 1180160 activation_246[0][0]
__________________________________________________________________________________________________
activation_247 (Activation) (None, 1, 1, 512) 0 conv2d_445[0][0]
__________________________________________________________________________________________________
conv2d_447 (Conv2D) (None, 1, 1, 512) 131584 add_138[0][0]
__________________________________________________________________________________________________
conv2d_446 (Conv2D) (None, 1, 1, 512) 2359808 activation_247[0][0]
__________________________________________________________________________________________________
add_139 (Add) (None, 1, 1, 512) 0 conv2d_447[0][0]
conv2d_446[0][0]
__________________________________________________________________________________________________
conv2d_transpose_43 (Conv2DTran (None, 2, 2, 256) 524544 add_139[0][0]
__________________________________________________________________________________________________
concatenate_44 (Concatenate) (None, 2, 2, 512) 0 conv2d_transpose_43[0][0]
add_138[0][0]
__________________________________________________________________________________________________
activation_248 (Activation) (None, 2, 2, 512) 0 concatenate_44[0][0]
__________________________________________________________________________________________________
conv2d_448 (Conv2D) (None, 2, 2, 256) 1179904 activation_248[0][0]
__________________________________________________________________________________________________
activation_249 (Activation) (None, 2, 2, 256) 0 conv2d_448[0][0]
__________________________________________________________________________________________________
conv2d_450 (Conv2D) (None, 2, 2, 256) 131328 concatenate_44[0][0]
__________________________________________________________________________________________________
conv2d_449 (Conv2D) (None, 2, 2, 256) 590080 activation_249[0][0]
__________________________________________________________________________________________________
add_140 (Add) (None, 2, 2, 256) 0 conv2d_450[0][0]
conv2d_449[0][0]
__________________________________________________________________________________________________
conv2d_transpose_44 (Conv2DTran (None, 4, 4, 128) 131200 add_140[0][0]
__________________________________________________________________________________________________
concatenate_45 (Concatenate) (None, 4, 4, 256) 0 conv2d_transpose_44[0][0]
add_137[0][0]
__________________________________________________________________________________________________
activation_250 (Activation) (None, 4, 4, 256) 0 concatenate_45[0][0]
__________________________________________________________________________________________________
conv2d_451 (Conv2D) (None, 4, 4, 128) 295040 activation_250[0][0]
__________________________________________________________________________________________________
activation_251 (Activation) (None, 4, 4, 128) 0 conv2d_451[0][0]
__________________________________________________________________________________________________
conv2d_453 (Conv2D) (None, 4, 4, 128) 32896 concatenate_45[0][0]
__________________________________________________________________________________________________
conv2d_452 (Conv2D) (None, 4, 4, 128) 147584 activation_251[0][0]
__________________________________________________________________________________________________
add_141 (Add) (None, 4, 4, 128) 0 conv2d_453[0][0]
conv2d_452[0][0]
__________________________________________________________________________________________________
conv2d_transpose_45 (Conv2DTran (None, 8, 8, 64) 32832 add_141[0][0]
__________________________________________________________________________________________________
concatenate_46 (Concatenate) (None, 8, 8, 128) 0 conv2d_transpose_45[0][0]
add_136[0][0]
__________________________________________________________________________________________________
activation_252 (Activation) (None, 8, 8, 128) 0 concatenate_46[0][0]
__________________________________________________________________________________________________
conv2d_454 (Conv2D) (None, 8, 8, 64) 73792 activation_252[0][0]
__________________________________________________________________________________________________
activation_253 (Activation) (None, 8, 8, 64) 0 conv2d_454[0][0]
__________________________________________________________________________________________________
conv2d_456 (Conv2D) (None, 8, 8, 64) 8256 concatenate_46[0][0]
__________________________________________________________________________________________________
conv2d_455 (Conv2D) (None, 8, 8, 64) 36928 activation_253[0][0]
__________________________________________________________________________________________________
add_142 (Add) (None, 8, 8, 64) 0 conv2d_456[0][0]
conv2d_455[0][0]
__________________________________________________________________________________________________
conv2d_457 (Conv2D) (None, 8, 8, 2) 130 add_142[0][0]
==================================================================================================
Total params: 32,939,586
Trainable params: 32,886,466
Non-trainable params: 53,120
__________________________________________________________________________________________________
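The parameter counts in the summary follow the usual formula: a Conv2D (or Conv2DTranspose) with a k × k kernel, c_in input channels and c_out filters has k·k·c_in·c_out weights plus c_out biases. This sketch checks a few rows taken from the summary; it also shows why conv2d_436, which reads ResNet50's 2048 channels, has over a million parameters while producing only 64 channels. The helper `conv_params` is a hypothetical name.

```python
def conv_params(kernel, c_in, c_out, bias=True):
    """Parameters of a Conv2D / Conv2DTranspose layer with a square kernel."""
    return kernel * kernel * c_in * c_out + (c_out if bias else 0)

# (layer name from the summary, kernel size, input channels, filters, reported count)
rows = [
    ("conv2d_436", 3, 2048, 64, 1179712),
    ("conv2d_438", 1, 2048, 64, 131136),
    ("conv2d_446", 3, 512, 512, 2359808),
    ("conv2d_transpose_43", 2, 512, 256, 524544),
    ("conv2d_457", 1, 64, 2, 130),
]
for name, k, c_in, c_out, reported in rows:
    computed = conv_params(k, c_in, c_out)
    print(f"{name}: computed={computed} reported={reported}")
```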