There are a few resources about this idea, such as a blog post on transferring a ResNet trained on RGB data to multi-channel images, and a related Colab notebook. Below is a working example based on those resources:
import numpy as np
import tensorflow as tf
def tile_kernels(kernel, out_channels, batch_dim=-2):
    """Build extra kernel slices by averaging the existing ones along an axis.

    The kernel is averaged over ``batch_dim``, and that single averaged slice is
    repeated ``out_channels`` times along the same axis. For a Keras ``Conv2D``
    kernel laid out as ``(kh, kw, in_channels, filters)``, the default ``-2`` is
    the input-channel axis. Each new input channel then gets, per filter, the
    mean of the weights of the existing input channels.

    Args:
        kernel: Convolution kernel array (any rank).
        out_channels: Number of new slices to produce along ``batch_dim``.
            Despite the name, these are extra *input* channels, not filters.
        batch_dim: Axis to average over and tile along.

    Returns:
        An array shaped like ``kernel`` except that axis ``batch_dim`` has
        length ``out_channels``.
    """
    kernel = np.asarray(kernel)
    # keepdims leaves the averaged axis in place with length 1. This works for
    # any axis and any kernel rank, unlike reshaping to a hard-coded 4-D shape.
    mean_slice = np.mean(kernel, axis=batch_dim, keepdims=True)
    reps = [1] * kernel.ndim
    reps[batch_dim] = out_channels
    return np.tile(mean_slice, reps)
def reshape_model_input(model_orig, custom_model, input_channels):
    """Copy weights from ``model_orig`` into ``custom_model``, in place.

    ``custom_model`` is expected to share layer names with ``model_orig`` but
    may take a different number of input channels.

    - Layers with weights of identical shapes are copied verbatim.
    - A layer whose weight shapes differ is treated as the input-widened layer.
      Its kernel (first weight) keeps the original input-channel slices along
      axis -2 and gets extra slices from ``tile_kernels``. Any remaining
      weights (e.g. the bias) are copied unchanged.
    - Layers without weights, or missing from ``custom_model``, are skipped.

    Args:
        model_orig: Source Keras model holding the weights to transfer.
        custom_model: Target Keras model, typically built from a modified
            config of ``model_orig``. It is modified in place.
        input_channels: Number of input channels the target model expects.

    Raises:
        ValueError: If ``input_channels`` is smaller than the source kernel's
            input-channel count. ``set_weights`` also raises if the widened
            weights still do not fit the target layer.
    """
    target_names = {layer.name for layer in custom_model.layers}
    for layer in model_orig.layers:
        if layer.name not in target_names:
            continue
        weights = layer.get_weights()
        if not weights:
            continue
        target_layer = custom_model.get_layer(layer.name)
        target_shapes = [w.shape for w in target_layer.get_weights()]
        if [w.shape for w in weights] == target_shapes:
            target_layer.set_weights(weights)
            continue

        # Shapes differ: this is the layer fed by the widened input. Derive the
        # original channel count from the kernel instead of assuming RGB (3).
        kernel, *rest = weights
        extra_channels = input_channels - kernel.shape[-2]
        if extra_channels < 0:
            raise ValueError(
                f"input_channels ({input_channels}) is smaller than the "
                f"{kernel.shape[-2]} input channels of layer '{layer.name}'"
            )
        widened = np.concatenate(
            (kernel, tile_kernels(kernel, extra_channels)), axis=-2
        )
        target_layer.set_weights([widened, *rest])
if __name__ == "__main__":
    from tensorflow.keras.applications import ResNet50V2

    # Load the pretrained source model (any Keras model could be used here).
    resnet50 = ResNet50V2(weights="imagenet", include_top=False)
    config = resnet50.get_config()

    img_height = 224  # set to your image height
    img_width = 224  # set to your image width
    input_channels = 7

    # Change the input layer's batch shape to the new channel count.
    # tf.keras 2.x stores it under "batch_input_shape"; Keras 3 uses "batch_shape".
    input_cfg = config["layers"][0]["config"]
    shape_key = "batch_shape" if "batch_shape" in input_cfg else "batch_input_shape"
    input_cfg[shape_key] = (None, img_height, img_width, input_channels)
    custom_resnet = tf.keras.models.Model.from_config(config)

    # Copy the pretrained weights into the custom model (modified in place).
    reshape_model_input(resnet50, custom_resnet, input_channels)

    # Sanity check: the dummy batch must be (N, height, width, channels) to
    # match the input shape set above.
    custom_resnet(np.zeros((1, img_height, img_width, input_channels)))
This process iterates over each layer in the original model and sets the corresponding weights in the custom model. For the n additional input channels (in your case n = 4, since you want 7 channels in total), the first convolution's kernel is averaged across its 3 RGB input channels, and that average is replicated n times (see the tile_kernels
function). Another aggregation function could be used instead, such as the max, min, or median. If you don't need the pretrained weights and only want the architecture, modify the original model's configuration and build a new model from it; the resulting model is randomly initialized:
resnet50 = ...  # the model whose architecture you want to reuse
config = resnet50.get_config()
img_height = ...  # your image height
img_width = ...  # your image width
input_channels = ...  # number of channels in your images
# Change the input layer's batch shape. tf.keras 2.x names the key
# "batch_input_shape"; Keras 3 names it "batch_shape".
input_cfg = config["layers"][0]["config"]
shape_key = "batch_shape" if "batch_shape" in input_cfg else "batch_input_shape"
input_cfg[shape_key] = (None, img_height, img_width, input_channels)
# No weights are copied, so the new model's weights are freshly initialized.
custom_resnet = tf.keras.models.Model.from_config(config)