
Note: All code for a self-contained example to reproduce my problem can be found below.

I have a tf.keras.models.Model() instance and would like to train it with a custom low-level TensorFlow API training loop. As part of this training loop, I need to make sure that all stateful variables of layers such as tf.keras.layers.BatchNormalization get updated. From this answer by Francois Chollet, I understand that this requires evaluating model.updates in every training step.

The problem is: this works when you feed the training data to the model via feed_dict, but it does not work when you use a tf.data.Dataset object.
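For reference, here is roughly what the working feed_dict variant looks like. This is only a minimal toy sketch (made-up shapes, a tiny throw-away model and random NumPy data rather than the actual Caltech256 classifier shown further below):

import numpy as np
import tensorflow as tf

# Tiny toy model with a BatchNorm layer.
input_tensor = tf.keras.layers.Input(shape=(8, 8, 3))
x = tf.keras.layers.Conv2D(4, (3,3))(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(2, activation='softmax')(x)
model = tf.keras.models.Model(input_tensor, x)

labels = tf.placeholder(dtype=tf.float32, shape=(None, 2))
loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=labels, output=model.output, from_logits=False))
train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetching model.updates alongside the train op is what refreshes the
    # BatchNorm moving mean/variance in this kind of custom loop.
    sess.run([train_op] + model.updates,
             feed_dict={model.input: np.random.rand(4, 8, 8, 3),
                        labels: np.eye(2)[[0, 1, 0, 1]],
                        tf.keras.backend.learning_phase(): 1})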

Consider the following abstract example (you can find a concrete example to reproduce the problem below):

model = tf.keras.models.Model(...) # Some tf.keras model
dataset = tf.data.Dataset.from_tensor_slices(...) # Some tf.data.Dataset
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

model_output = model(features)

with tf.Session() as sess:
    ret = sess.run(model.updates)

This sess.run() call throws the error

InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]

This error obviously shouldn't be raised. I don't need to feed a value for the placeholder input_1, because I'm calling my model on a tf.data.Dataset, not feeding input data to a placeholder via the feed_dict.

What can I do to make this work?

Here is a fully reproducible example. It's a simple image classifier being trained on Caltech256 (download the TFRecord files using the link at the bottom of this post):

import tensorflow as tf
from tqdm import trange
import sys
import glob
import os

sess = tf.Session()
tf.keras.backend.set_session(sess)

num_classes = 257
image_size = (224, 224, 3)

# Build a simple CNN with BatchNorm layers.

input_tensor = tf.keras.layers.Input(shape=image_size)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), kernel_initializer='he_normal')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), kernel_initializer='he_normal')(x)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(128, (3,3), strides=(2,2), kernel_initializer='he_normal')(x)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(256, (3,3), strides=(2,2), kernel_initializer='he_normal')(x)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(x)
model = tf.keras.models.Model(input_tensor, x)

# We'll monitor whether the moving mean and moving variance of the first BatchNorm layer are being updated as they should be.
moving_mean = tf.reduce_mean(model.layers[2].moving_mean)
moving_variance = tf.reduce_mean(model.layers[2].moving_variance)

# Build a tf.data.Dataset from TFRecords.

tfrecord_directory = '/path/to/the/tfrecord/files/'

tfrecord_filenames = glob.glob(os.path.join(tfrecord_directory, '*.tfrecord'))

feature_schema = {'image': tf.FixedLenFeature([], tf.string),
                  'filename': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)}

dataset = tf.data.Dataset.from_tensor_slices(tfrecord_filenames)
dataset = dataset.shuffle(len(tfrecord_filenames)) # Shuffle the TFRecord file names.
dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename))
dataset = dataset.map(lambda single_example_proto: tf.parse_single_example(single_example_proto, feature_schema)) # Deserialize tf.Example objects.
dataset = dataset.map(lambda sample: (sample['image'], sample['label']))
dataset = dataset.map(lambda image, label: (tf.image.decode_jpeg(image, channels=3), label)) # Decode JPEG images.
dataset = dataset.map(lambda image, label: (tf.image.resize_image_with_pad(image, target_height=image_size[0], target_width=image_size[1]), label))
dataset = dataset.map(lambda image, label: (tf.image.per_image_standardization(image), label))
dataset = dataset.map(lambda image, label: (image, tf.one_hot(indices=label, depth=num_classes))) # Convert labels to one-hot format.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat()
dataset = dataset.batch(32)

iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()

# Build the training-relevant part of the graph.

model_output = model(batch_features)

loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=batch_labels, output=model_output, from_logits=False))

train_step = tf.train.AdamOptimizer().minimize(loss)

# The next block is for the metrics.
with tf.variable_scope('metrics') as scope:
    predictions_argmax = tf.argmax(model_output, axis=-1, output_type=tf.int64)
    labels_argmax = tf.argmax(batch_labels, axis=-1, output_type=tf.int64)
    mean_loss_value, mean_loss_update_op = tf.metrics.mean(loss)
    acc_value, acc_update_op = tf.metrics.accuracy(labels=labels_argmax, predictions=predictions_argmax)
    local_metric_vars = tf.contrib.framework.get_variables(scope=scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    metrics_reset_op = tf.variables_initializer(var_list=local_metric_vars, name='metrics_reset_op')

# Run the training.

epochs = 3
steps_per_epoch = 1000

fetch_list = [mean_loss_value,
              acc_value,
              moving_mean,
              moving_variance,
              train_step,
              mean_loss_update_op,
              acc_update_op] + model.updates

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

with sess.as_default():

    for epoch in range(1, epochs+1):

        tr = trange(steps_per_epoch, file=sys.stdout)
        tr.set_description('Epoch {}/{}'.format(epoch, epochs))

        sess.run(metrics_reset_op)

        for step in tr:

            ret = sess.run(fetches=fetch_list, feed_dict={tf.keras.backend.learning_phase(): 1})

            tr.set_postfix(ordered_dict={'loss': ret[0],
                                         'accuracy': ret[1],
                                         'bn1 moving mean': ret[2],
                                         'bn1 moving variance': ret[3]})

Running this code throws the error described above:

InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]

A very ugly work-around would be to fetch the next batch in a separate sess.run() call and then feed the resulting NumPy arrays to a second sess.run() call via feed_dict. This works, but it obviously partially defeats the purpose of using the tf.data API:

# Build the training-relevant part of the graph.

labels = tf.placeholder(dtype=tf.float32, shape=(None, num_classes), name='labels')

loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=labels, output=model.output, from_logits=False))

train_step = tf.train.AdamOptimizer().minimize(loss)

with tf.variable_scope('metrics') as scope:
    predictions_argmax = tf.argmax(model.output, axis=-1, output_type=tf.int64)
    labels_argmax = tf.argmax(labels, axis=-1, output_type=tf.int64)
    mean_loss_value, mean_loss_update_op = tf.metrics.mean(loss)
    acc_value, acc_update_op = tf.metrics.accuracy(labels=labels_argmax, predictions=predictions_argmax)
    local_metric_vars = tf.contrib.framework.get_variables(scope=scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    metrics_reset_op = tf.variables_initializer(var_list=local_metric_vars, name='metrics_reset_op')

# Run the training. With BatchNorm.

epochs = 3
steps_per_epoch = 1000

fetch_list = [mean_loss_value,
              acc_value,
              moving_mean,
              moving_variance,
              train_step,
              mean_loss_update_op,
              acc_update_op] + model.updates

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

with sess.as_default():

    for epoch in range(1, epochs+1):

        tr = trange(steps_per_epoch, file=sys.stdout)
        tr.set_description('Epoch {}/{}'.format(epoch, epochs))

        sess.run(metrics_reset_op)

        for step in tr:

            b_images, b_labels = sess.run([batch_features, batch_labels])

            ret = sess.run(fetches=fetch_list, feed_dict={tf.keras.backend.learning_phase(): 1,
                                                          model.input: b_images,
                                                          labels: b_labels})

            tr.set_postfix(ordered_dict={'loss': ret[0],
                                         'accuracy': ret[1],
                                         'bn1 moving mean': ret[2],
                                         'bn1 moving variance': ret[3]})

As mentioned above, this is just a bad work-around. How can I make this work properly?

You can download the TFRecord files here.

Alex
  • Instead of defining a new session, shouldn't you get the current session from Keras backend, using K.get_session() (K is the keras backend) and then run your model.updates? The current session would already have the inputs fed through the dataset object right? – kvish Feb 10 '19 at 15:59
  • @kvish thanks for the suggestion. I tried it, but it doesn't solve the problem :/ – Alex Feb 10 '19 at 19:58
  • maybe you can use the make_initializable_iterator instead of one_shot_iterator? You need to run the iterator initializer, then run model.updates? Here is the [official tensorflow guide](https://www.tensorflow.org/guide/datasets#creating_an_iterator) covering how it works. – kvish Feb 10 '19 at 20:15
  • @kvish worth a try, but could you elaborate on how exactly you would do this? I can add the `input_1` placeholder to the `feed_dict` of the iterator initializer, but what value do I pass to it? – Alex Feb 15 '19 at 19:11
  • check the example for [creating an iterator](https://www.tensorflow.org/guide/datasets#creating_an_iterator). You would make iterator = dataset.make_initializable_iterator() in your code, and then get the session from Keras backend using sess = K.get_session(), and simply run the iterator initializer with sess.run(iterator.initializer). Then you get your model.updates. You do not need to define any placeholder. – kvish Feb 15 '19 at 23:45
  • @kvish I tried your suggestion, but it doesn't work either. I'm also not quite sure I understand the idea behind it. I've updated my question with a fully reproducible example. – Alex Feb 16 '19 at 00:06
  • Thanks for the update @Alex. I will try the code today! I just had a quick glance of your code. Just curious, any particular reason for using tf optimizers and metrics with Keras? – kvish Feb 16 '19 at 00:24
  • @kvish no strong reason, I'm just more familiar with the TF ops. I don't know, for example, how I would use the tf.keras.optimizers.Adam optimizer in a low-level TF training loop. Overall I need the low-level TF training loop, because I need to add operations to the training-relevant part of the graph and fetch their outputs during the training (the norms of the gradients of all layers). This is just inconvenient to do with Keras' model.compile() and model.fit() methods. But as long as I can use a low-level TF training loop, I'd be happy for it to consist only of tf.keras objects. – Alex Feb 16 '19 at 00:33
  • Have you already tried wrapping your TensorFlow tensor into an Input layer? At least that is what they did here: https://stackoverflow.com/a/46140332/7482962. The documentation may also be helpful here: https://keras.io/layers/core/#input. In short, the layer is necessary to convert the tf tensor into a Keras tensor. Otherwise there is also the possibility to convert the Keras model into a tf.Estimator, but I wouldn't recommend that approach, to be honest. – p13rr0m Feb 16 '19 at 00:48
  • @pierrom that is exactly right! Of course, why didn't I have this idea! All it took was setting `input_tensor = tf.keras.layers.Input(tensor=batch_features)`. Thanks a lot, 50 reputation are yours if you post an answer :) – Alex Feb 16 '19 at 01:00

1 Answer


The problem is this line:

model_output = model(batch_features)

It's generally fine to call a Keras model on a tensor, but in this case it causes problems. When the model was created, its input layer created a placeholder tensor, and that placeholder still demands a value when you evaluate model.updates. Calling the model on batch_features afterwards does not replace it. Instead, you need to build the model's input layer on top of batch_features in the first place, so that no separate placeholder is ever created; the correct input has to be wired in at model instantiation, afterwards it's too late. This is done like so:

input_tensor = tf.keras.layers.Input(tensor=batch_features)

Now running model.updates works just fine.
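For completeness, here is a minimal, self-contained sketch of the resulting wiring. It uses toy random data instead of the Caltech256 TFRecords and a much smaller network than in the question; only the construction of the input layer is the point:

import numpy as np
import tensorflow as tf

num_classes = 2
images = np.random.rand(64, 8, 8, 3).astype(np.float32)
labels = np.eye(num_classes, dtype=np.float32)[np.random.randint(0, num_classes, 64)]

dataset = tf.data.Dataset.from_tensor_slices((images, labels)).repeat().batch(8)
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()

# The crucial line: the input layer is built on the dataset's output tensor,
# so the model never creates a stand-alone input placeholder.
input_tensor = tf.keras.layers.Input(tensor=batch_features)
x = tf.keras.layers.Conv2D(4, (3,3))(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
model = tf.keras.models.Model(input_tensor, x)

loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=batch_labels, output=model.output, from_logits=False))
train_op = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # model.updates (the BatchNorm moving-average updates) can now be fetched
    # without feeding anything for the input.
    sess.run([train_op] + model.updates,
             feed_dict={tf.keras.backend.learning_phase(): 1})

The only substantive change compared to the code in the question is that tf.keras.layers.Input receives the tensor= argument instead of shape=; everything downstream of the input layer stays exactly the same.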

Alex