The following code, which uses the TF-Slim library to load a model and fine-tune it, achieves 90% accuracy on a classification task (I omitted the data loading and preprocessing):

with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=0.0001)):
    logits, _ = resnet_v1.resnet_v1_50(images, num_classes=dataset.num_classes, is_training=True)

one_hot_labels = slim.one_hot_encoding(labels, NUM_CLASSES)
tf.losses.softmax_cross_entropy(one_hot_labels, logits)
total_loss = tf.losses.get_total_loss()
global_step = variables.get_or_create_global_step()
lr = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, GAMMA)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
train_op = slim.learning.create_train_op(total_loss, optimizer, global_step=global_step)
init_fn = slim.assign_from_checkpoint_fn("resnet_v1_50.ckpt", VARIABLES_TO_RESTORE)

final_loss = slim.learning.train(
    train_op,
    logdir=train_dir,
    log_every_n_steps=500,
    save_summaries_secs=25,
    init_fn=init_fn,
    number_of_steps=NUM_STEPS)

I tried rewriting the same code in vanilla TensorFlow to have more control over the training process, and for some reason I cannot achieve the same performance (a 10% drop) with all the same hyperparameters (in uppercase) and the same preprocessing. The differences are in the graph definition:

lr = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, GAMMA)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
full_train_op = optimizer.minimize(total_loss, global_step=global_step)

and training:

for s in range(NUM_STEPS):
    sess.run(train_init_op)  # Initializes the dataset iterator
    while True:
        try:
            sess.run([full_train_op], feed_dict={is_training: True})
        except tf.errors.OutOfRangeError:
            break

Is the slim train function doing some other operations behind the scenes? I thought it might be handling batch normalization or something else that I did not implement in my version of the code.

Is it possible to load the slim ResNet model in TensorFlow and train it without the slim train function? I am not interested in overriding train_step_fn.

1 Answer

This may be due to not running the update_ops associated with ResNet's batch normalization layers.

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.MomentumOptimizer(learning_rate=lr, momentum=MOMENTUM)
# Make the train op depend on the batch-norm moving-average updates
with tf.control_dependencies(update_ops):
    full_train_op = optimizer.minimize(total_loss, global_step)
# same training loop
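
For context, slim.learning.create_train_op does this for you: when no update_ops argument is passed, it reads the tf.GraphKeys.UPDATE_OPS collection and attaches those ops as control dependencies of the train op, which is why the slim version kept the moving averages fresh. As a quick sanity check (a sketch, assuming the graph from the question is already built), you can list the collected ops; for resnet_v1_50 these are the moving-mean/variance assignments of its batch-norm layers:

# Sketch: inspect the update ops that resnet_v1_50's batch-norm layers
# registered in the UPDATE_OPS collection (assumes the graph is built).
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(len(update_ops))   # should be non-empty for a batch-norm network
for op in update_ops[:5]:
    print(op.name)       # e.g. ".../BatchNorm/AssignMovingAvg"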
– DomJack
  • Thank you for your help. Would I need to do anything extra for inference during validation? My loop is `while True: try: sess.run(tf_metric_update, feed_dict={is_training: False}) except tf.errors.OutOfRangeError: break` followed by `train_acc = sess.run(tf_metric, feed_dict={is_training: False})`, where the metric is defined in the graph as `tf_metric, tf_metric_update = tf.metrics.accuracy(labels, predictions, name="accuracy_metric")`. – Francisco Salgado Jun 18 '18 at 07:01
  • Nope. Update ops only run during training. So long as `is_training` is False, evaluation uses the moving averages that are now updated during training rather than the batch statistics. Without running the update ops, evaluation would use the initial moving-average values. – DomJack Jun 19 '18 at 08:09
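
To make the comment exchange concrete, here is the validation pass reformatted (a sketch assuming the `tf_metric`/`tf_metric_update` tensors and the `is_training` placeholder from the question; `val_init_op` is a hypothetical iterator initializer for the validation split; note that tf.metrics.accuracy keeps its counts in local variables, so reset them before each evaluation):

# Sketch of the evaluation pass discussed in the comments.
sess.run(tf.local_variables_initializer())  # reset the accuracy counters
sess.run(val_init_op)  # hypothetical: point the iterator at the validation set
while True:
    try:
        # is_training=False -> batch norm uses the learned moving averages
        sess.run(tf_metric_update, feed_dict={is_training: False})
    except tf.errors.OutOfRangeError:
        break
val_acc = sess.run(tf_metric)
print("validation accuracy:", val_acc)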