When running the `model.fit` function, an error is thrown. My main question is: what does this error mean? The code runs on a TPU v3-8 and uses Google Cloud for data retrieval. I did try to look up the error on the web, but I could not find a single case of anyone else getting it.
```python
model.fit(
    dataset,
    steps_per_epoch=N_IMGS // BATCH_SIZE,
    epochs=EPOCHS,
)
```
This throws the following error:
```
InvalidArgumentError: {{function_node __inference_train_function_528542}} Compilation failure: Depth of output must be a multiple of the number of groups: 3 vs 2
	 [[{{node sequential/conv2d/Conv2D}}]]
	TPU compilation failed
	 [[tpu_compile_succeeded_assert/_15965336225898828069/_5]]
```
The error message is not clear to me: what exactly is going wrong? The following model is used:
```python
import tensorflow as tf
import efficientnet.tfkeras as efn  # assumed: the `efficientnet` package, which provides 'noisy-student' weights

# strategy, HEIGHT, WIDTH and N_LABELS are defined elsewhere in the notebook (not shown)
def get_model():
    # reset to free memory and training variables
    tf.keras.backend.clear_session()
    with strategy.scope():
        net = efn.EfficientNetB0(include_top=False, weights='noisy-student', input_shape=(HEIGHT, WIDTH, 3))
        model = tf.keras.Sequential([
            # map the 1-channel input images to the 3 channels EfficientNet expects
            tf.keras.layers.Conv2D(3, (3, 3), padding='same', input_shape=(HEIGHT, WIDTH, 1)),
            net,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Dense(N_LABELS, activation='softmax', dtype='float32'),
        ])
        model.compile(optimizer=tf.keras.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

model = get_model()
tf.keras.utils.plot_model(model, 'model.png', show_shapes=True)
```
Inspecting the first batch of the dataset gives the following output:
```python
for images, labels in dataset.take(1):  # only take first element of dataset
    print(f'images.shape: {images.shape}, images.dtype: {images.dtype}, labels.shape: {labels.shape}, labels.dtype: {labels.dtype}')
```

```
images.shape: (64, 224, 400, 1), images.dtype: <dtype: 'float32'>, labels.shape: (64,), labels.dtype: <dtype: 'int32'>
```
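One possible reading of the message, sketched below under an assumption: that the compiler derives the number of convolution groups as the input tensor's channel depth divided by the kernel's input depth, and then requires the number of filters (the output depth) to be divisible by that group count. The helpers `feature_group_count` and `check_conv_depths` are hypothetical, written only to illustrate that arithmetic; they are not TensorFlow APIs. With the shapes above (1 input channel, a kernel built for 1 channel, 3 filters) the check passes, while a 2-channel tensor reaching the same layer would produce exactly "3 vs 2".

```python
# Hypothetical helpers illustrating the (assumed) grouped-convolution depth check.

def feature_group_count(in_depth: int, filter_in_depth: int) -> int:
    """Number of groups implied by input depth vs. the kernel's input depth."""
    if in_depth % filter_in_depth != 0:
        raise ValueError(
            f"input depth {in_depth} is not a multiple of kernel depth {filter_in_depth}"
        )
    return in_depth // filter_in_depth


def check_conv_depths(in_depth: int, filter_in_depth: int, out_depth: int) -> int:
    """Raise if out_depth is not divisible by the implied group count; else return it."""
    groups = feature_group_count(in_depth, filter_in_depth)
    if out_depth % groups != 0:
        raise ValueError(
            f"Depth of output must be a multiple of the number of groups: {out_depth} vs {groups}"
        )
    return groups


# Shapes from the question: 1-channel images, a kernel built for 1 channel, 3 filters.
print(check_conv_depths(in_depth=1, filter_in_depth=1, out_depth=3))  # prints 1

# If a 2-channel tensor reached the same layer, the implied group count would be 2.
try:
    check_conv_depths(in_depth=2, filter_in_depth=1, out_depth=3)
except ValueError as err:
    print(err)  # Depth of output must be a multiple of the number of groups: 3 vs 2
```

Under this assumption, the "2" in the message refers to the group count rather than to any layer setting in the model.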