2

I am currently training a large object detection model in TensorFlow 2 with a custom training loop using `tf.GradientTape`. The problem is that the loss does not improve because the gradients are very small. I reproduced the problem on a simple classification task using CIFAR-10 and discovered that a small model trains fine, while a larger model (VGG16) does not improve the loss at all. Below is some code for reproducing the problem.

VGG16 model:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, Dropout, MaxPooling2D, BatchNormalization, Input, Concatenate
import os

def create_vgg16(number_classes, include_fully=True, input_shape=(300, 300, 3), input_tensor=None):
    """Build a VGG16-style Keras model.

    Args:
        number_classes: Width of the final softmax layer; only used when
            ``include_fully`` is true.
        include_fully: Whether to append the flatten / fc1 / fc2 /
            predictions head after the convolutional stages.
        input_shape: Shape handed to ``Input`` when ``input_tensor`` is None.
        input_tensor: Optional existing tensor to build the network on.

    Returns:
        A ``tf.keras.models.Model`` named ``'vgg16'``.
    """
    # (filters, number of 3x3 conv layers) for each of the five stages.
    stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))

    img_input = Input(shape=input_shape) if input_tensor is None else input_tensor

    x = img_input
    for stage_idx, (filters, n_convs) in enumerate(stages, start=1):
        for conv_idx in range(1, n_convs + 1):
            x = Conv2D(
                filters,
                (3, 3),
                activation='relu',
                padding='same',
                kernel_initializer='he_normal',
                name='conv%d_%d' % (stage_idx, conv_idx),
            )(x)

        # The last stage pools with a 3x3 window and stride 1, so it does not
        # downsample; every earlier stage halves the resolution.
        if stage_idx == len(stages):
            window, step = (3, 3), (1, 1)
        else:
            window, step = (2, 2), (2, 2)
        x = MaxPooling2D(pool_size=window, strides=step, padding='same',
                         name='pool%d' % stage_idx)(x)

    if include_fully:
        x = Flatten(name='flatten')(x)
        for fc_name in ('fc1', 'fc2'):
            x = Dense(4096, activation='relu', name=fc_name)(x)
        x = Dense(number_classes, activation='softmax', name='predictions')(x)

    # When built on a caller-supplied tensor, the model inputs are whatever
    # source inputs that tensor originates from.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = tf.keras.utils.get_source_inputs(input_tensor)

    return tf.keras.models.Model(inputs, x, name='vgg16')

Small CNN model:

def create_small_cnn(n_classes, input_shape=(32, 32, 3)):
    """Build a small two-convolution CNN classifier.

    Args:
        n_classes: Number of units in the final softmax layer.
        input_shape: Shape passed to ``tf.keras.Input``.

    Returns:
        A ``tf.keras.Model`` named ``'small_cnn'``.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_shape)

    # Layers are listed in the order they are applied to the input.
    stack = [
        layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_1'),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_2'),
        layers.Flatten(name='flatten'),
        layers.Dense(16, activation='relu', name='fc1'),
        layers.Dense(n_classes, activation='softmax', name='softmax'),
    ]

    outputs = inputs
    for layer in stack:
        outputs = layer(outputs)

    return tf.keras.Model(inputs, outputs, name='small_cnn')

Training loop:

def main():
    """Train a CIFAR-10 classifier with a custom ``tf.GradientTape`` loop.

    Loads CIFAR-10, one-hot encodes the labels, builds the model and trains
    it for 100 epochs. The loss of the current batch is printed every 200
    steps, and the training and validation accuracy are printed after every
    epoch.
    """
    number_classes = 10
    batch_size = 64

    # Load the data; flatten the label arrays and one-hot encode them.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    y_train = tf.one_hot(tf.reshape(y_train, [-1]), number_classes).numpy()
    y_test = tf.one_hot(tf.reshape(y_test, [-1]), number_classes).numpy()

    # Define model
    model = create_vgg16(number_classes, input_shape=(32, 32, 3))
    # model = create_small_cnn(number_classes, input_shape=(32, 32, 3))

    # Instantiate an optimizer to train the model.
    # `learning_rate` replaces the deprecated `lr` alias; `decay=0.0` was the
    # default and is dropped because newer Keras optimizers reject it.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    # Instantiate a loss function. The model ends in a softmax, so the loss
    # receives probabilities (from_logits defaults to False).
    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    # Prepare the metrics.
    train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

    # Prepare the training dataset (pixel values scaled by 1/255).
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_train / 255, tf.float32),
         tf.cast(y_train, tf.int64)))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    # Prepare the validation dataset. Accuracy over the whole set does not
    # depend on the sample order, so no shuffling is needed here.
    val_dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_test / 255, tf.float32),
         tf.cast(y_test, tf.int64)))
    val_dataset = val_dataset.batch(batch_size)

    model.summary()

    for epoch in range(100):
        print(f'Start of epoch {epoch}')
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                # Run the forward pass in training mode so that layers with
                # train-time behavior (dropout, batch norm) act correctly.
                predictions = model(x_batch_train, training=True)
                loss_value = loss_fn(y_batch_train, predictions)
            grads = tape.gradient(loss_value, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Update the metric with the WHOLE batch and all class scores.
            # Indexing [0] / [:-1] here would score only the first sample and
            # drop the last class probability.
            train_acc_metric.update_state(y_batch_train, predictions)

            if step % 200 == 0:
                print(f'Training loss (for one batch) at step {step}: {float(loss_value)}')

        # Display metrics at the end of each epoch.
        train_acc = train_acc_metric.result()
        print(f'Training acc over epoch: {float(train_acc)}')
        # Reset training metrics at the end of each epoch
        train_acc_metric.reset_states()

        # Run a validation loop at the end of each epoch.
        for x_batch_val, y_batch_val in val_dataset:
            val_predictions = model(x_batch_val, training=False)
            val_acc_metric.update_state(y_batch_val, val_predictions)
        val_acc = val_acc_metric.result()
        val_acc_metric.reset_states()
        print(f'Validation acc: {float(val_acc)}')


# Run the training script only when executed directly, not when imported.
if "__main__" == __name__:
    main()

If you run the code shown, you will see that the network trains fine when using the small CNN model. However, training does not work on the exact same dataset, with the same preprocessing, when using a standard VGG16 model. To make matters more confusing, the VGG model trains perfectly fine when using model.fit instead of a custom training loop with gradient tape.

Does anybody have an idea why this is the case and how to fix this problem?

Maximilian
  • 21
  • 2

0 Answers0