I think the easiest solution here, if I have understood your question correctly, is to separate your data sets based on your value of `i`.
So take your `X_train` and split it into `X_train_1`, `X_train_2`, etc. Likewise, split your `X_test` into `X_test_1`, `X_test_2`, etc.
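The question isn't quoted here, so the exact layout of your data is a guess. The following is a minimal NumPy sketch. It assumes (hypothetically) that `X_train` has shape `(num_samples, num_groups, 24, 24, 3)` and that `i` indexes the group axis. `split_by_group` is a name made up for this example.

```python
import numpy as np

def split_by_group(X, axis=1):
    """Hypothetical helper: return one array per index i along `axis`.

    Assumes X has shape (num_samples, num_groups, ...), i.e. the value of i
    selects a slice of every sample, so each returned array keeps the
    sample axis first (which is what Keras expects).
    """
    return [np.take(X, i, axis=axis) for i in range(X.shape[axis])]

# Toy data: 10 samples, 2 groups, each group a 24x24 RGB image
X_train = np.random.rand(10, 2, 24, 24, 3)
X_train_1, X_train_2 = split_by_group(X_train)
print(X_train_1.shape)  # (10, 24, 24, 3)
```

Apply the same split to `X_test`. If your groups are stored differently (separate files, separate columns), build each `X_train_k` directly. The only requirement is that row `j` of every split describes the same sample as `y_train[j]`.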
Start with the imports:
```python
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense, Concatenate
from keras.utils import plot_model
```
Then set up separate models:
```python
# Branch for the first split (will be fed X_train_1)
model1 = Sequential()
model1.add(Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(24, 24, 3)))
model1.add(MaxPooling2D(pool_size=(2, 2)))
model1.add(Flatten())
model1.add(Dropout(0.5))
model1.add(Dense(512, activation="relu"))

# Branch for the second split (will be fed X_train_2)
model2 = Sequential()
model2.add(Conv2D(32, kernel_size=(3, 3), activation="relu", input_shape=(24, 24, 3)))
model2.add(MaxPooling2D(pool_size=(2, 2)))
model2.add(Flatten())
model2.add(Dropout(0.5))
model2.add(Dense(512, activation="relu"))
```
You will want to use the functional API to combine them. I have used `Concatenate()`; other merge layers such as `Add()`, `Average()` and `Multiply()` are described in the Keras documentation on merging layers.
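```python
# Merge the two branch outputs, then add a shared classification head
outputs = Concatenate()([model1.output, model2.output])
outputs = Dense(256, activation="relu")(outputs)
outputs = Dropout(0.5)(outputs)
outputs = Dense(5, activation="softmax")(outputs)  # 5 output classes
```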
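To see what the merge does to the shapes, here is a rough NumPy analogy (not Keras code). It uses a batch of 4 and the two 512-unit branch outputs, and compares `Concatenate()` (which joins along the last axis by default) with `Add()`:

```python
import numpy as np

batch = 4
branch_1 = np.ones((batch, 512))       # stands in for model1.output
branch_2 = np.full((batch, 512), 2.0)  # stands in for model2.output

# Concatenate: features are placed side by side, 512 + 512 -> 1024
concatenated = np.concatenate([branch_1, branch_2], axis=-1)

# Add: element-wise sum, so both branches must have the same shape; stays 512
added = branch_1 + branch_2

print(concatenated.shape)  # (4, 1024)
print(added.shape)         # (4, 512)
```

`Concatenate` keeps both branches' features separate for the next `Dense` layer to weigh. `Add` and `Average` force the branches into a shared feature space and require matching output sizes.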
Now configure your final model, specifying the inputs and outputs:
```python
# .input is each sub-model's single input tensor (.inputs would be a one-element list)
model = Model(inputs=[model1.input, model2.input], outputs=outputs)
```
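If you have more than two splits, writing one `Sequential` per split gets repetitive. The sketch below builds the same architecture purely with the functional API. `conv_branch` and `n_inputs` are names introduced here for illustration (they are not from the answer above), and `n_inputs = 3` is an arbitrary example value.

```python
import numpy as np
from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dropout, Dense, Concatenate
from keras.models import Model

def conv_branch(inp):
    """Hypothetical helper: the same layer stack as model1/model2, applied to one input."""
    x = Conv2D(32, kernel_size=(3, 3), activation="relu")(inp)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dropout(0.5)(x)
    return Dense(512, activation="relu")(x)

n_inputs = 3  # example: one input each for X_train_1, X_train_2, X_train_3
inputs = [Input(shape=(24, 24, 3)) for _ in range(n_inputs)]

# Each call to conv_branch creates new layers, so the branches do NOT share weights
merged = Concatenate()([conv_branch(inp) for inp in inputs])
x = Dense(256, activation="relu")(merged)
x = Dropout(0.5)(x)
outputs = Dense(5, activation="softmax")(x)

model = Model(inputs=inputs, outputs=outputs)

# Smoke test on dummy data: one array per input, in the same order as `inputs`
dummy = [np.zeros((2, 24, 24, 3)) for _ in range(n_inputs)]
preds = model.predict(dummy, verbose=0)
print(preds.shape)  # (2, 5)
```

With `n_inputs = 2` this reproduces the two-branch model above. Passing the same `Input` list order to `fit` and `predict` is what ties each split to its branch.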
Checking `model.summary()`, you can see how each layer is connected:
```
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
conv2d_input (InputLayer)       [(None, 24, 24, 3)]  0
__________________________________________________________________________________________________
conv2d_1_input (InputLayer)     [(None, 24, 24, 3)]  0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 22, 22, 32)   896         conv2d_input[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 22, 22, 32)   896         conv2d_1_input[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 11, 11, 32)   0           conv2d[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 11, 11, 32)   0           conv2d_1[0][0]
__________________________________________________________________________________________________
flatten (Flatten)               (None, 3872)         0           max_pooling2d[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 3872)         0           max_pooling2d_1[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 3872)         0           flatten[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 3872)         0           flatten_1[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 512)          1982976     dropout[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 512)          1982976     dropout_1[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 1024)         0           dense[0][0]
                                                                 dense_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 256)          262400      concatenate[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 256)          0           dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 5)            1285        dropout_2[0][0]
==================================================================================================
Total params: 4,231,429
Trainable params: 4,231,429
Non-trainable params: 0
__________________________________________________________________________________________________
```
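The `Param #` column can be checked by hand. A `Conv2D` layer has kernel height × width × input channels × filters weights plus one bias per filter. A `Dense` layer has inputs × units weights plus one bias per unit. The pure-Python sketch below applies these formulas; the helper names are made up for this illustration.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # one weight per kernel position per input channel per filter, plus one bias per filter
    return kernel_h * kernel_w * in_channels * filters + filters

def dense_params(in_features, units):
    # full weight matrix plus one bias per unit
    return in_features * units + units

conv = conv2d_params(3, 3, 3, 32)  # 896
# 'valid' 3x3 conv: 24 -> 22; 2x2 max pooling: 22 -> 11; flattened: 11 * 11 * 32
flat = 11 * 11 * 32                      # 3872
dense_branch = dense_params(flat, 512)   # 1982976
head = dense_params(2 * 512, 256) + dense_params(256, 5)  # 262400 + 1285

total = 2 * (conv + dense_branch) + head
print(total)  # 4231429
```

Almost all of the parameters sit in the two `Dense(512)` layers after `Flatten`. That is the first place to shrink if the model is too large.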
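But it's easier to visualise the model with `plot_model(model, to_file='image.png', show_shapes=True)`, which saves a diagram of the layer graph, including each layer's input and output shapes, to `image.png` (`plot_model` needs the `pydot` package and Graphviz installed).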
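Then, to train the model, compile it and feed in the different inputs in the same order as in `Model(inputs=...)`, not forgetting your test (or validation) data:
```python
# Loss assumes one-hot encoded labels; use "sparse_categorical_crossentropy" for integer labels
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

model.fit(
    [X_train_1, X_train_2], y_train,
    validation_data=([X_test_1, X_test_2], y_test),
    # ...add any other fit() arguments here (epochs, batch_size, ...)
)
```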
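Keras pairs row `j` of every input array with `y_train[j]`, so the splits must stay aligned. If you shuffle or subsample the data yourself before calling `fit`, use one shared permutation for every array. `shuffle_together` below is a hypothetical helper sketched for this purpose, not a Keras function.

```python
import numpy as np

def shuffle_together(xs, y, seed=0):
    """Hypothetical helper: shuffle several input arrays and their labels
    with ONE shared permutation so that row j still matches across all of them."""
    n = len(y)
    for k, x in enumerate(xs, start=1):
        if len(x) != n:
            raise ValueError(f"input {k} has {len(x)} samples, labels have {n}")
    perm = np.random.default_rng(seed).permutation(n)
    return [x[perm] for x in xs], y[perm]

# Toy data where alignment is easy to check: X_train_2 is always 10x X_train_1
X_train_1 = np.arange(6).reshape(6, 1)
X_train_2 = X_train_1 * 10
y_train = np.arange(6)

xs, y = shuffle_together([X_train_1, X_train_2], y_train)
print(np.array_equal(xs[1], xs[0] * 10))  # True: rows are still aligned
```

Shuffling each split independently would silently pair images from one sample with labels from another, and training would still run without any error.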
NB: The sub-models (here `model1`, `model2`, etc.) don't have to have the same structure. They can have different sized layers, different numbers of layers, and different types of layers. This is also how you can include data sets with different types of features in your model, for example images in one branch and tabular features in another.
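The following minimal sketch illustrates the NB with two deliberately different branches. One is a small convolutional branch for the 24x24 images. The other is a `Dense` branch for a hypothetical table of 10 numeric features per sample. The feature count, the layer sizes and the choice of `GlobalAveragePooling2D` are all illustrative assumptions.

```python
import numpy as np
from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, Concatenate
from keras.models import Model

# Branch 1: images (24x24 RGB, as in the answer)
img_in = Input(shape=(24, 24, 3))
x1 = Conv2D(16, kernel_size=(3, 3), activation="relu")(img_in)
x1 = GlobalAveragePooling2D()(x1)  # -> (None, 16)

# Branch 2: hypothetical tabular input with 10 numeric features
tab_in = Input(shape=(10,))
x2 = Dense(32, activation="relu")(tab_in)
x2 = Dense(16, activation="relu")(x2)  # -> (None, 16)

# Branches only need compatible shapes at the merge point (here, the batch axis)
merged = Concatenate()([x1, x2])  # -> (None, 32)
out = Dense(5, activation="softmax")(merged)
model = Model(inputs=[img_in, tab_in], outputs=out)

# Smoke test: 3 samples, one array per input, in the order given to Model(inputs=...)
preds = model.predict([np.zeros((3, 24, 24, 3)), np.zeros((3, 10))], verbose=0)
print(preds.shape)  # (3, 5)
```

Training works the same way as above: `model.fit([X_images, X_table], y_train, ...)`, where `X_images` and `X_table` are placeholder names for your two arrays.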