I'm trying to fit a Keras (TF 2.3.1) model for image classification with multiple binary labels as output. The model consists of an Xception CNN + attention layer + dense classifier, and training hits the following error, but only on some TPUs:
UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported
This fails on Kaggle TPUs but not on Colab; both environments were tested on TF 2.3.1.
I was looking here, but the suggested solution implies that the image dimensions are not set, which is not the case here. train_df
is of type <PrefetchDataset shapes: ((None, 750, 750, 3), (None, 11)), types: (tf.float32, tf.int64)>
so each image has a fixed size of 750x750x3. Each layer has a defined output shape per the model summary below, so the layers that follow should be able to infer their input shapes correctly.
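For reference, the shape information above is just what printing the dataset shows; a minimal check, assuming train_df is a batched tf.data.Dataset:

print(train_df)
# <PrefetchDataset shapes: ((None, 750, 750, 3), (None, 11)), types: (tf.float32, tf.int64)>
print(train_df.element_spec)  # same shapes as a structured (TensorSpec, TensorSpec) pair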
From the error, the problem seems to be in the layer defined by attn_layer = LocallyConnected2D(...). Passing implementation = 2
is a workaround that lets training complete, but it is not suitable for large models (see the LocallyConnected2D documentation).
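For completeness, the workaround is just the extra implementation argument on that attention layer; a minimal sketch of the variant that does train (implementation = 2 stores the layer weights as one dense matrix, hence the memory cost for large models):

# Workaround: lets training complete on the Kaggle TPU, but not viable for large models
attn_layer = LocallyConnected2D(1, kernel_size = (1,1), padding = 'valid',
                                activation = 'sigmoid',
                                implementation = 2)(attn_layer)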
Modelling code:
import numpy as np
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import Xception
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten, Input, Conv2D, multiply, LocallyConnected2D, Lambda, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import mean_absolute_error
def create_model():
    input_shape = (TARGET_SIZE, TARGET_SIZE, 3)
    in_lay = Input(input_shape)
    conv_base = Xception(include_top = False, weights = 'imagenet', input_shape = input_shape)
    pt_features = conv_base(in_lay)
    bn_features = BatchNormalization()(pt_features)
    # here we do an attention mechanism to turn pixels in the GAP on and off
    attn_layer = Conv2D(64, kernel_size = (1,1), padding = 'same', activation = 'relu')(bn_features)
    attn_layer = Conv2D(16, kernel_size = (1,1), padding = 'same', activation = 'relu')(attn_layer)
    attn_layer = LocallyConnected2D(1, kernel_size = (1,1), padding = 'valid', activation = 'sigmoid')(attn_layer)
    # fan it out to all of the channels
    pt_depth = conv_base.get_output_shape_at(0)[-1]
    up_c2_w = np.ones((1, 1, 1, pt_depth))
    up_c2 = Conv2D(pt_depth, kernel_size = (1,1), padding = 'same',
                   activation = 'linear', use_bias = False, weights = [up_c2_w])
    up_c2.trainable = False
    attn_layer = up_c2(attn_layer)
    mask_features = multiply([attn_layer, bn_features])
    gap_features = GlobalAveragePooling2D()(mask_features)
    gap_mask = GlobalAveragePooling2D()(attn_layer)
    # to account for missing values from the attention model
    gap = Lambda(lambda x: x[0]/x[1], name = 'RescaleGAP')([gap_features, gap_mask])
    gap_dr = Dropout(0.5)(gap)
    dr_steps = Dropout(0.25)(Dense(1024, activation = 'elu')(gap_dr))
    out_layer = Dense(11, activation = 'sigmoid')(dr_steps)
    model = Model(inputs = [in_lay], outputs = [out_layer])
    model.compile(optimizer = Adam(lr = 0.002), loss = 'binary_crossentropy', metrics = ["AUC"])
    return model
with tpu_strategy.scope():
    model = create_model()

model.summary()

history = model.fit(
    train_df,
    epochs = EPOCHS,
    steps_per_epoch = STEPS_PER_EPOCH,
    validation_data = valid_df,
    validation_steps = VALIDATION_STEPS
)
The resulting model summary:
Model: "model_8"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_19 (InputLayer) [(None, 750, 750, 3) 0
__________________________________________________________________________________________________
xception (Model) (None, 24, 24, 2048) 20861480 input_19[0][0]
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 24, 24, 2048) 8192 xception[1][0]
__________________________________________________________________________________________________
conv2d_67 (Conv2D) (None, 24, 24, 64) 131136 batch_normalization_49[0][0]
__________________________________________________________________________________________________
conv2d_68 (Conv2D) (None, 24, 24, 16) 1040 conv2d_67[0][0]
__________________________________________________________________________________________________
locally_connected2d_9 (LocallyC (None, 24, 24, 1) 9792 conv2d_68[0][0]
__________________________________________________________________________________________________
conv2d_69 (Conv2D) (None, 24, 24, 2048) 2048 locally_connected2d_9[0][0]
__________________________________________________________________________________________________
multiply_9 (Multiply) (None, 24, 24, 2048) 0 conv2d_69[0][0]
batch_normalization_49[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_23 (Gl (None, 2048) 0 multiply_9[0][0]
__________________________________________________________________________________________________
global_average_pooling2d_24 (Gl (None, 2048) 0 conv2d_69[0][0]
__________________________________________________________________________________________________
RescaleGAP (Lambda) (None, 2048) 0 global_average_pooling2d_23[0][0]
global_average_pooling2d_24[0][0]
__________________________________________________________________________________________________
dropout_18 (Dropout) (None, 2048) 0 RescaleGAP[0][0]
__________________________________________________________________________________________________
dense_17 (Dense) (None, 1024) 2098176 dropout_18[0][0]
__________________________________________________________________________________________________
dropout_19 (Dropout) (None, 1024) 0 dense_17[0][0]
__________________________________________________________________________________________________
dense_18 (Dense) (None, 11) 11275 dropout_19[0][0]
==================================================================================================
Total params: 23,123,139
Trainable params: 23,062,467
Non-trainable params: 60,672
__________________________________________________________________________________________________
Full stack trace + error message:
---------------------------------------------------------------------------
UnimplementedError Traceback (most recent call last)
<ipython-input-53-5130a0bcf331> in <module>
19 validation_data = valid_df,
20 validation_steps = VALIDATION_STEPS,
---> 21 callbacks = [model_save, early_stop, reduce_lr]
22 )
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
853 context.async_wait()
854 logs = tmp_logs # No error, now safe to assign to logs.
--> 855 callbacks.on_train_batch_end(step, logs)
856 epoch_logs = copy.copy(logs)
857
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self, batch, logs)
387 """
388 if self._should_call_train_batch_hooks:
--> 389 logs = self._process_logs(logs)
390 self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
391
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _process_logs(self, logs)
263 """Turns tensors into numpy arrays or Python scalars."""
264 if logs:
--> 265 return tf_utils.to_numpy_or_python_type(logs)
266 return {}
267
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in to_numpy_or_python_type(tensors)
521 return t # Don't turn ragged or sparse tensors to NumPy.
522
--> 523 return nest.map_structure(_to_single_numpy_or_python_type, tensors)
524
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func, *structure, **kwargs)
615
616 return pack_sequence_as(
--> 617 structure[0], [func(*x) for x in entries],
618 expand_composites=expand_composites)
619
/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
615
616 return pack_sequence_as(
--> 617 structure[0], [func(*x) for x in entries],
618 expand_composites=expand_composites)
619
/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
517 def _to_single_numpy_or_python_type(t):
518 if isinstance(t, ops.Tensor):
--> 519 x = t.numpy()
520 return x.item() if np.ndim(x) == 0 else x
521 return t # Don't turn ragged or sparse tensors to NumPy.
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in numpy(self)
959 """
960 # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
--> 961 maybe_arr = self._numpy() # pylint: disable=protected-access
962 return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
963
/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _numpy(self)
927 return self._numpy_internal()
928 except core._NotOkStatusException as e:
--> 929 six.raise_from(core._status_to_exception(e.code, e.message), None)
930
931 @property
/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)
UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported: %convolution.30660 = f32[<=8,24,24,2048]{3,2,1,0} convolution(f32[<=8,24,24,1]{3,2,1,0} %add.30633, f32[1,1,1,2048]{3,2,1,0} %get-tuple-element.354), window={size=1x1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="model_8/conv2d_69/Conv2D"}
TPU compilation failed
[[{{node tpu_compile_succeeded_assert/_17367812259898276239/_5}}]]