I'm trying to fit a Keras (TF 2.3.1) model for image classification with multiple binary labels as output. The model consists of an Xception CNN + attention layer + dense classifier, and I'm hitting an error on some TPUs only: UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported. This fails on Kaggle TPUs but not on Colab; I tested both on TF version 2.3.1.

I was looking here, but the suggested solution assumes that the image dimensions are not set, which is not the case here. train_df is of type <PrefetchDataset shapes: ((None, 750, 750, 3), (None, 11)), types: (tf.float32, tf.int64)>, so each image has size 750x750x3 and only the batch dimension is None. Each layer has a defined output shape per the model summary below, so each subsequent layer should infer its input shape correctly.
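The dataset type above does show a batch dimension of None, and the `<=8` in the full error message further down hints at a dynamic batch size. A minimal sketch, using hypothetical stand-in tensors rather than the real train_df pipeline, of how `drop_remainder=True` in `Dataset.batch` changes the reported batch dimension. Whether this resolves the error on Kaggle TPUs is an untested assumption:

```python
import tensorflow as tf

# Hypothetical stand-in data: 10 tiny images with 11 binary labels each.
images = tf.zeros([10, 4, 4, 3], dtype=tf.float32)
labels = tf.zeros([10, 11], dtype=tf.int64)
base_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Default batching: the final batch may be short, so the batch dimension is None.
dynamic_ds = base_ds.batch(8)
# drop_remainder=True discards the short final batch, so the batch dimension is static.
static_ds = base_ds.batch(8, drop_remainder=True)

dynamic_shape = dynamic_ds.element_spec[0].shape.as_list()
static_shape = static_ds.element_spec[0].shape.as_list()
print(dynamic_shape)  # [None, 4, 4, 3]
print(static_shape)   # [8, 4, 4, 3]
```

Checking `train_df.element_spec` in the same way would show whether the real pipeline exposes a static batch size.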

From the error, it seems that the problem is in the layer defined by attn_layer = LocallyConnected2D(...). Passing implementation = 2 is a workaround that lets training complete, but this is not suitable for large models (see the LocallyConnected2D documentation).
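For reference, with a 1x1 kernel and a single filter, a locally connected layer reduces to a separate weighted sum over channels at every spatial position. The NumPy sketch below illustrates that computation with random stand-in values; the names `x`, `w` and `b` are hypothetical, and this is not the Keras implementation itself:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels = 2, 24, 24, 16

x = rng.normal(size=(batch, height, width, channels))
# One weight vector and one bias per spatial position (unshared, unlike Conv2D).
w = rng.normal(size=(height, width, channels))
b = rng.normal(size=(height, width))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Per-position weighted sum over channels, then sigmoid; the output has 1 channel.
attn = sigmoid(np.einsum("bhwc,hwc->bhw", x, w) + b)[..., np.newaxis]

# 24*24*16 weights plus 24*24 biases.
n_params = w.size + b.size
print(attn.shape, n_params)
```

The parameter count works out to 9792, which matches the locally_connected2d_9 row in the model summary.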

Modelling code:

import numpy as np  # needed for np.ones in create_model
import tensorflow as tf
from tensorflow.keras import models, layers
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import Xception
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout, Flatten, Input, Conv2D, multiply, LocallyConnected2D, Lambda, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import mean_absolute_error

def create_model():
    input_shape = (TARGET_SIZE, TARGET_SIZE, 3)
    in_lay = Input(input_shape)
    conv_base = Xception(include_top = False, weights = 'imagenet', input_shape = input_shape)
    pt_features = conv_base(in_lay)
    bn_features = BatchNormalization()(pt_features)

    # here we do an attention mechanism to turn pixels in the GAP on and off
    attn_layer = Conv2D(64, kernel_size = (1,1), padding = 'same', activation = 'relu')(bn_features)
    attn_layer = Conv2D(16, kernel_size = (1,1), padding = 'same', activation = 'relu')(attn_layer)
    attn_layer = LocallyConnected2D(1, kernel_size = (1,1), padding = 'valid', activation = 'sigmoid')(attn_layer)
    # fan it out to all of the channels
    pt_depth = conv_base.get_output_shape_at(0)[-1]
    up_c2_w = np.ones((1, 1, 1, pt_depth))
    up_c2 = Conv2D(pt_depth, kernel_size = (1,1), padding = 'same', 
                activation = 'linear', use_bias = False, weights = [up_c2_w])
    up_c2.trainable = False
    attn_layer = up_c2(attn_layer)

    mask_features = multiply([attn_layer, bn_features])
    gap_features = GlobalAveragePooling2D()(mask_features)
    gap_mask = GlobalAveragePooling2D()(attn_layer)
    # to account for missing values from the attention model
    gap = Lambda(lambda x: x[0]/x[1], name = 'RescaleGAP')([gap_features, gap_mask])
    gap_dr = Dropout(0.5)(gap)
    dr_steps = Dropout(0.25)(Dense(1024, activation = 'elu')(gap_dr))
    out_layer = Dense(11, activation = 'sigmoid')(dr_steps)
    model = Model(inputs = [in_lay], outputs = [out_layer])
    model.compile(optimizer = Adam(learning_rate = 0.002), loss = 'binary_crossentropy', metrics = ["AUC"])
    return model


with tpu_strategy.scope():
    model = create_model()
model.summary()

history = model.fit(
    train_df,
    epochs = EPOCHS,
    steps_per_epoch = STEPS_PER_EPOCH,
    validation_data = valid_df,
    validation_steps = VALIDATION_STEPS
)

The resulting model summary:

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_19 (InputLayer)           [(None, 750, 750, 3) 0                                            
__________________________________________________________________________________________________
xception (Model)                (None, 24, 24, 2048) 20861480    input_19[0][0]                   
__________________________________________________________________________________________________
batch_normalization_49 (BatchNo (None, 24, 24, 2048) 8192        xception[1][0]                   
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, 24, 24, 64)   131136      batch_normalization_49[0][0]     
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, 24, 24, 16)   1040        conv2d_67[0][0]                  
__________________________________________________________________________________________________
locally_connected2d_9 (LocallyC (None, 24, 24, 1)    9792        conv2d_68[0][0]                  
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 24, 24, 2048) 2048        locally_connected2d_9[0][0]      
__________________________________________________________________________________________________
multiply_9 (Multiply)           (None, 24, 24, 2048) 0           conv2d_69[0][0]                  
                                                                 batch_normalization_49[0][0]     
__________________________________________________________________________________________________
global_average_pooling2d_23 (Gl (None, 2048)         0           multiply_9[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d_24 (Gl (None, 2048)         0           conv2d_69[0][0]                  
__________________________________________________________________________________________________
RescaleGAP (Lambda)             (None, 2048)         0           global_average_pooling2d_23[0][0]
                                                                 global_average_pooling2d_24[0][0]
__________________________________________________________________________________________________
dropout_18 (Dropout)            (None, 2048)         0           RescaleGAP[0][0]                 
__________________________________________________________________________________________________
dense_17 (Dense)                (None, 1024)         2098176     dropout_18[0][0]                 
__________________________________________________________________________________________________
dropout_19 (Dropout)            (None, 1024)         0           dense_17[0][0]                   
__________________________________________________________________________________________________
dense_18 (Dense)                (None, 11)           11275       dropout_19[0][0]                 
==================================================================================================
Total params: 23,123,139
Trainable params: 23,062,467
Non-trainable params: 60,672
__________________________________________________________________________________________________

Full stacktrace + error message:

---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
<ipython-input-53-5130a0bcf331> in <module>
     19     validation_data = valid_df,
     20     validation_steps = VALIDATION_STEPS,
---> 21     callbacks = [model_save, early_stop, reduce_lr]
     22 )

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    853                 context.async_wait()
    854               logs = tmp_logs  # No error, now safe to assign to logs.
--> 855               callbacks.on_train_batch_end(step, logs)
    856         epoch_logs = copy.copy(logs)
    857 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in on_train_batch_end(self, batch, logs)
    387     """
    388     if self._should_call_train_batch_hooks:
--> 389       logs = self._process_logs(logs)
    390       self._call_batch_hook(ModeKeys.TRAIN, 'end', batch, logs=logs)
    391 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py in _process_logs(self, logs)
    263     """Turns tensors into numpy arrays or Python scalars."""
    264     if logs:
--> 265       return tf_utils.to_numpy_or_python_type(logs)
    266     return {}
    267 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in to_numpy_or_python_type(tensors)
    521     return t  # Don't turn ragged or sparse tensors to NumPy.
    522 
--> 523   return nest.map_structure(_to_single_numpy_or_python_type, tensors)
    524 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func, *structure, **kwargs)
    615 
    616   return pack_sequence_as(
--> 617       structure[0], [func(*x) for x in entries],
    618       expand_composites=expand_composites)
    619 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/util/nest.py in <listcomp>(.0)
    615 
    616   return pack_sequence_as(
--> 617       structure[0], [func(*x) for x in entries],
    618       expand_composites=expand_composites)
    619 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/tf_utils.py in _to_single_numpy_or_python_type(t)
    517   def _to_single_numpy_or_python_type(t):
    518     if isinstance(t, ops.Tensor):
--> 519       x = t.numpy()
    520       return x.item() if np.ndim(x) == 0 else x
    521     return t  # Don't turn ragged or sparse tensors to NumPy.

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in numpy(self)
    959     """
    960     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
--> 961     maybe_arr = self._numpy()  # pylint: disable=protected-access
    962     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
    963 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _numpy(self)
    927       return self._numpy_internal()
    928     except core._NotOkStatusException as e:
--> 929       six.raise_from(core._status_to_exception(e.code, e.message), None)
    930 
    931   @property

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

UnimplementedError: {{function_node __inference_train_function_644557}} Compilation failure: Dynamic Spatial Convolution is not supported: %convolution.30660 = f32[<=8,24,24,2048]{3,2,1,0} convolution(f32[<=8,24,24,1]{3,2,1,0} %add.30633, f32[1,1,1,2048]{3,2,1,0} %get-tuple-element.354), window={size=1x1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="model_8/conv2d_69/Conv2D"}
    TPU compilation failed
     [[{{node tpu_compile_succeeded_assert/_17367812259898276239/_5}}]]
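Note that the op named in the message is model_8/conv2d_69/Conv2D, with input f32[<=8,24,24,1] and kernel f32[1,1,1,2048], which corresponds to the frozen all-ones fan-out convolution (up_c2) rather than the LocallyConnected2D layer itself. A NumPy sketch with random stand-in values suggests that this convolution is numerically the same as plain broadcasting, so it might be possible to drop it; whether a broadcast multiply in Keras compiles on the Kaggle TPU is an untested assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
pt_depth = 2048  # channel count of the Xception output in the summary
attn = rng.uniform(size=(2, 24, 24, 1))            # stand-in attention mask
features = rng.normal(size=(2, 24, 24, pt_depth))  # stand-in bn_features

# 1x1 convolution: 1 input channel, pt_depth output channels, all-ones kernel, no bias.
kernel = np.ones((1, 1, 1, pt_depth))
fanned_out = np.einsum("bhwi,io->bhwo", attn, kernel[0, 0])

# Masking via the fan-out convolution versus via broadcasting.
masked_via_conv = fanned_out * features
masked_via_broadcast = attn * features
same_mask = np.allclose(masked_via_conv, masked_via_broadcast)

# The pooled mask used for rescaling also agrees up to broadcasting.
gap_mask_conv = fanned_out.mean(axis=(1, 2))   # shape (2, 2048)
gap_mask_broadcast = attn.mean(axis=(1, 2))    # shape (2, 1)
same_gap = np.allclose(gap_mask_conv, gap_mask_broadcast)
print(same_mask, same_gap)  # True True
```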