
Running the model.fit function throws an error. My main question is: what does this error mean? The code runs on a TPU v3-8 and retrieves its data from Google Cloud. I did search the web for the error, but I could not find a single case of anyone else getting it.

model.fit(
    dataset,
    steps_per_epoch = N_IMGS // BATCH_SIZE,
    epochs = EPOCHS,
)

This throws the following error:

InvalidArgumentError: {{function_node __inference_train_function_528542}} Compilation failure: Depth of output must be a multiple of the number of groups: 3 vs 2
     [[{{node sequential/conv2d/Conv2D}}]]
    TPU compilation failed
     [[tpu_compile_succeeded_assert/_15965336225898828069/_5]]
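The message can be unpacked with a little arithmetic. Below is a minimal plain-Python sketch, assuming the rule that TensorFlow's XLA convolution lowering appears to apply: the number of groups is inferred as the input depth divided by the kernel's input depth, and the output depth (the number of filters) must be a multiple of that group count. The helper names inferred_groups and check_conv are hypothetical and exist only for illustration.

```python
def inferred_groups(input_depth, kernel_in_depth):
    """Group count inferred as input depth / kernel input depth (assumed rule)."""
    if input_depth % kernel_in_depth != 0:
        raise ValueError(
            f"input depth {input_depth} is not a multiple of "
            f"kernel input depth {kernel_in_depth}"
        )
    return input_depth // kernel_in_depth


def check_conv(input_depth, kernel_in_depth, output_depth):
    """Return the group count, or raise with a message shaped like the TPU error."""
    groups = inferred_groups(input_depth, kernel_in_depth)
    if output_depth % groups != 0:
        raise ValueError(
            "Depth of output must be a multiple of the number of groups: "
            f"{output_depth} vs {groups}"
        )
    return groups


# Intended case: 1-channel images, kernel built for 1 channel, 3 filters.
print(check_conv(input_depth=1, kernel_in_depth=1, output_depth=3))  # 1

# Hypothetical case: a 2-channel tensor reaches the same layer.
try:
    check_conv(input_depth=2, kernel_in_depth=1, output_depth=3)
except ValueError as err:
    print(err)  # Depth of output must be a multiple of the number of groups: 3 vs 2
```

Under this assumed rule, the first Conv2D (kernel built for a 1-channel input, 3 filters) would only report "3 vs 2" if the tensor reaching it during TPU compilation had depth 2. That is an inference from the wording of the message, not a confirmed diagnosis.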

The error message is not clear to me. What exactly is going wrong? This is the model I am using:

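import tensorflow as tf
import efficientnet.tfkeras as efn  # assumed: the `efficientnet` package, which offers 'noisy-student' weights

# strategy, HEIGHT, WIDTH and N_LABELS are defined elsewhere (not shown in the question)
def get_model():
    # reset to free memory and training variables
    tf.keras.backend.clear_session()

    with strategy.scope():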
        net = efn.EfficientNetB0(include_top=False, weights='noisy-student', input_shape=(HEIGHT, WIDTH, 3))

        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(3, (3, 3), padding='same', input_shape=(HEIGHT, WIDTH, 1)),
            net,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Dense(N_LABELS, activation='softmax', dtype='float32'),
        ])
    
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    
    return model

model = get_model()
tf.keras.utils.plot_model(model, 'model.png', show_shapes=True)

[Image: summary of the CNN model used]
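For reference, this is the kernel the first Conv2D layer should end up with, given input_shape=(HEIGHT, WIDTH, 1). It is a minimal plain-Python sketch, assuming Keras's default channels_last layout, where a Conv2D kernel is stored as (kernel_height, kernel_width, input_channels, filters) and a bias is added per filter by default; the helper names are hypothetical.

```python
def conv2d_kernel_shape(kernel_size, in_channels, filters):
    """Kernel shape of a channels_last Conv2D: (kh, kw, in_channels, filters)."""
    kh, kw = kernel_size
    return (kh, kw, in_channels, filters)


def conv2d_param_count(kernel_size, in_channels, filters, use_bias=True):
    """Trainable parameters: kernel weights plus one bias per filter."""
    kh, kw = kernel_size
    return kh * kw * in_channels * filters + (filters if use_bias else 0)


# First layer of the model above: Conv2D(3, (3, 3)) on a 1-channel input.
print(conv2d_kernel_shape((3, 3), in_channels=1, filters=3))  # (3, 3, 1, 3)
print(conv2d_param_count((3, 3), in_channels=1, filters=3))   # 30
```

If the layer is built as intended, model.summary() should list 30 parameters for the conv2d layer, and its kernel input depth of 1 is the value the group-count arithmetic divides by.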

The dataset produces the following output:

for images, labels in dataset.take(1): # only take first element of dataset
    print(f'images.shape: {images.shape}, images.dtype: {images.dtype}, labels.shape: {labels.shape}, labels.dtype: {labels.dtype}')

images.shape: (64, 224, 400, 1), images.dtype: <dtype: 'float32'>, labels.shape: (64,), labels.dtype: <dtype: 'int32'>
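Only the first batch was inspected above (dataset.take(1)). Below is a minimal sketch of a shape check that could be applied to more batches, for example by passing tuple(images.shape) from each element of dataset.take(n). It assumes HEIGHT = 224 and WIDTH = 400 (read off the printed shape) and 8 replicas, the number of cores on a TPU v3-8 (in practice this would be strategy.num_replicas_in_sync); check_batch is a hypothetical helper.

```python
# Assumptions: HEIGHT/WIDTH are read off the printed images.shape, and
# NUM_REPLICAS stands in for strategy.num_replicas_in_sync on a TPU v3-8.
HEIGHT, WIDTH = 224, 400
NUM_REPLICAS = 8


def check_batch(images_shape, height, width, channels=1, num_replicas=1):
    """Validate a (batch, h, w, c) shape and return the per-replica shape.

    Hypothetical helper: raises ValueError if the per-image shape differs
    from what the model's first layer expects, or if the global batch
    cannot be split evenly across replicas.
    """
    batch, h, w, c = images_shape
    if (h, w, c) != (height, width, channels):
        raise ValueError(
            f"per-image shape {(h, w, c)} != expected {(height, width, channels)}"
        )
    if batch % num_replicas != 0:
        raise ValueError(
            f"global batch {batch} not divisible by {num_replicas} replicas"
        )
    return (batch // num_replicas, h, w, c)


# Shape printed in the question: (64, 224, 400, 1)
print(check_batch((64, 224, 400, 1), HEIGHT, WIDTH, num_replicas=NUM_REPLICAS))
# prints (8, 224, 400, 1)
```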

Mark wijkhuizen