I am currently training a large object detection model in TensorFlow 2 with a custom training loop using gradient tape. The problem is that the loss is not improving, because the gradients are very small. I reproduced the problem on a simple classification task using CIFAR-10 and discovered that a small model trains fine, while a larger model (VGG16) does not improve the loss at all. Below is some code for reproducing the problem.
VGG16 model:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Dropout, MaxPooling2D, BatchNormalization, Input, Concatenate
import os
def create_vgg16(number_classes, include_fully=True, input_shape=(300, 300, 3), input_tensor=None):
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        img_input = input_tensor

    x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv1_1')(img_input)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv1_2')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool1')(x)

    x = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv2_1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv2_2')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool2')(x)

    x = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv3_1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv3_2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv3_3')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool3')(x)

    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv4_1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv4_2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv4_3')(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same', name='pool4')(x)

    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv5_1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv5_2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal', name='conv5_3')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=(1, 1), padding='same', name='pool5')(x)

    if include_fully:
        x = Flatten(name='flatten')(x)
        x = Dense(4096, activation='relu', name='fc1')(x)
        x = Dense(4096, activation='relu', name='fc2')(x)
        x = Dense(number_classes, activation='softmax', name='predictions')(x)

    if input_tensor is not None:
        inputs = tf.keras.utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    model = tf.keras.models.Model(inputs, x, name='vgg16')
    return model
Small CNN model:
def create_small_cnn(n_classes, input_shape=(32, 32, 3)):
    img_input = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_1')(img_input)
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_2')(x)
    x = tf.keras.layers.Flatten(name='flatten')(x)
    x = tf.keras.layers.Dense(16, activation='relu', name='fc1')(x)
    x = tf.keras.layers.Dense(n_classes, activation='softmax', name='softmax')(x)
    model = tf.keras.Model(img_input, x, name='small_cnn')
    return model
Training loop:
def main():
    number_classes = 10

    # Load and one-hot encode the data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    y_train = tf.reshape(y_train, [-1])
    y_train = tf.one_hot(y_train, number_classes).numpy()
    y_test = tf.reshape(y_test, [-1])
    y_test = tf.one_hot(y_test, number_classes).numpy()

    # Define model
    model = create_vgg16(number_classes, input_shape=(32, 32, 3))
    # model = create_small_cnn(number_classes, input_shape=(32, 32, 3))

    # Instantiate an optimizer to train the model.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Instantiate a loss function.
    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    # Prepare the metrics.
    train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

    # Prepare the training dataset.
    batch_size = 64
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_train / 255, tf.float32),
         tf.cast(y_train, tf.int64)))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    # Prepare the validation dataset.
    val_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_test / 255, tf.float32),
         tf.cast(y_test, tf.int64)))
    val_dataset = val_dataset.shuffle(buffer_size=1024).batch(batch_size)

    model.summary()

    for epoch in range(100):
        print('Start of epoch %d' % (epoch,))
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                # Note: the model ends in a softmax, so these are
                # probabilities rather than raw logits.
                logits = model(x_batch_train)
                loss_value = loss_fn(y_batch_train, logits)
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            # Update the training metric on the full batch.
            train_acc_metric.update_state(y_batch_train, logits)
            if step % 200 == 0:
                print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))

        # Display metrics at the end of each epoch.
        train_acc = train_acc_metric.result()
        print('Training acc over epoch: %s' % (float(train_acc),))
        # Reset training metrics at the end of each epoch.
        train_acc_metric.reset_states()

        # Run a validation loop at the end of each epoch.
        for x_batch_val, y_batch_val in val_dataset:
            val_logits = model(x_batch_val)
            val_acc_metric.update_state(y_batch_val, val_logits)
        val_acc = val_acc_metric.result()
        val_acc_metric.reset_states()
        print('Validation acc: %s' % (float(val_acc),))


if __name__ == '__main__':
    main()
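To show how small the gradients actually are, the global gradient norm can be logged for a single batch. This is a minimal sketch that reuses model, loss_fn and train_dataset from main() above; tf.linalg.global_norm computes the L2 norm over all gradient tensors:

# Minimal sketch: log the overall gradient magnitude for one batch.
# Reuses model, loss_fn and train_dataset from main() above.
for x_batch, y_batch in train_dataset.take(1):
    with tf.GradientTape() as tape:
        preds = model(x_batch)
        loss_value = loss_fn(y_batch, preds)
    grads = tape.gradient(loss_value, model.trainable_variables)
    # Global L2 norm across all gradient tensors.
    print('Global gradient norm: %s' % float(tf.linalg.global_norm(grads)))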
If you run the code shown, the network trains fine with the small CNN model, but it does not work on the exact same dataset with the same preprocessing when using the standard VGG16 model. To make matters more confusing, the VGG model trains perfectly fine when using model.fit instead of the custom training loop with gradient tape. A sketch of that fit setup follows below.
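For comparison, this is roughly the model.fit setup that does converge (a minimal sketch, assuming the same model and the train_dataset/val_dataset built in main() above):

# Sketch of the model.fit path that trains fine (same data pipeline as above).
model = create_vgg16(number_classes, input_shape=(32, 32, 3))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()])
model.fit(train_dataset, validation_data=val_dataset, epochs=100)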
Does anybody have an idea why this is the case and how to fix this problem?