10

When operating in graph mode in TF1, I believe I needed to wire up training=True and training=False via feed dicts when using the functional-style API. What is the proper way to do this in TF2?

I believe this is automatically handled when using tf.keras.Sequential. For example, I don't need to specify training in the following example from the docs:

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))

Can I also assume that Keras automagically handles this when training with the functional API? Here is the same model, rewritten using the functional API:

inputs = tf.keras.Input(shape=(28, 28, 1), name="input_image")
hid = tf.keras.layers.Conv2D(32, 3, activation='relu',
                             kernel_regularizer=tf.keras.regularizers.l2(0.02))(inputs)
hid = tf.keras.layers.MaxPooling2D()(hid)
hid = tf.keras.layers.Flatten()(hid)
hid = tf.keras.layers.Dropout(0.1)(hid)
hid = tf.keras.layers.Dense(64, activation='relu')(hid)
hid = tf.keras.layers.BatchNormalization()(hid)
outputs = tf.keras.layers.Dense(10, activation='softmax')(hid)
model_fn = tf.keras.Model(inputs=inputs, outputs=outputs)

# Model is the full model w/o custom layers
model_fn.compile(optimizer='adam',
                 loss='sparse_categorical_crossentropy',
                 metrics=['accuracy'])

model_fn.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model_fn.evaluate(test_data)
print("Loss {:0.4f}, Accuracy {:0.4f}".format(loss, acc))

I'm unsure whether hid = tf.keras.layers.BatchNormalization()(hid) needs to be hid = tf.keras.layers.BatchNormalization()(hid, training=training)?
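
For context, the workaround I have in mind, if the functional API does not handle this for me, would be to subclass tf.keras.Model and thread training through call() explicitly. A sketch of what I mean (my own, and I haven't confirmed it's actually necessary):

# Hypothetical subclassed version of the same model, forwarding `training` by hand.
class MnistModel(tf.keras.Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.conv = tf.keras.layers.Conv2D(32, 3, activation='relu',
                                           kernel_regularizer=tf.keras.regularizers.l2(0.02))
        self.pool = tf.keras.layers.MaxPooling2D()
        self.flatten = tf.keras.layers.Flatten()
        self.dropout = tf.keras.layers.Dropout(0.1)
        self.dense = tf.keras.layers.Dense(64, activation='relu')
        self.bn = tf.keras.layers.BatchNormalization()
        self.out = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x, training=False):
        x = self.conv(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x, training=training)  # explicitly forward the flag
        x = self.dense(x)
        x = self.bn(x, training=training)       # explicitly forward the flag
        return self.out(x)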

A colab for these models can be found here.

cosentiyes
  • Do you have a specific reason to want to control the training flag, or are you asking if it's needed at all? – Dr. Snoopy Nov 06 '19 at 10:43
  • I guess I would want to be able to set it in a forward pass on `model_fn()` (`tf.keras.Model#call`) so that BatchNormalization behaves correctly. I assume I would need to subclass model and define the forward pass call explicitly so that I can pass `training` to the BN invocation, similarly to the example in https://www.tensorflow.org/api_docs/python/tf/keras/Model. I would also like to know if it is needed at all _when using `model_fn.fit()`_. – cosentiyes Nov 06 '19 at 11:46
  • @cosentiyes: You mentioned *I believe this is automatically handled when using `tf.keras.Sequential`*. Are you sure this is true? Do you have any reference which proves that? – Nerxis Apr 27 '20 at 14:02

2 Answers

7

I realized that there is a bug in the BatchNormalization documentation [1] where the {{TRAINABLE_ATTRIBUTE_NOTE}} placeholder isn't actually replaced with the intended note [2]:

About setting layer.trainable = False on a BatchNormalization layer: The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run. Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch). This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case. Note that:

  • This behavior only occurs as of TensorFlow 2.0. In 1.*, setting layer.trainable = False would freeze the layer but would not switch it to inference mode.
  • Setting trainable on a model containing other layers will recursively set the trainable value of all inner layers.
  • If the value of the trainable attribute is changed after calling compile() on a model, the new value doesn't take effect for this model until compile() is called again.

[1] https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization?version=stable

[2] https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/keras/layers/normalization_v2.py#L26-L65
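
To make the distinction concrete, here is a minimal sketch (my own illustration on a toy model, not from the docs) of the fine-tuning behaviour the note describes:

# Toy model just to illustrate freezing a BatchNormalization layer in TF >= 2.0.
inputs = tf.keras.Input(shape=(16,))
x = tf.keras.layers.Dense(32, activation='relu')(inputs)
x = tf.keras.layers.BatchNormalization(name='bn')(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs, outputs)

# Freeze the BN layer: this stops its weight/statistics updates AND switches it to
# inference mode (moving mean/variance) during fit(), as per the note above.
model.get_layer('bn').trainable = False

# The new `trainable` value only takes effect after (re-)compiling.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')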

cosentiyes
  • Thanks for your question and answer, I'm looking exactly for the same. What about `Dropout` layer? It's a bit different as it's about switching on (training) and off (inference). Do you know if this is handled (somehow) by default or do you need to deal with it by yourself? – Nerxis Apr 27 '20 at 14:34
6

As for the original, broader question of whether you have to manually pass the training flag when using the Keras functional API, this example from the official docs suggests that you do not need to:

# ...

x = Dropout(0.5)(x)
outputs = Linear(10)(x)
model = tf.keras.Model(inputs, outputs)

# ...

# You can pass a `training` argument in `__call__`
# (it will get passed down to the Dropout layer).
y = model(tf.ones((2, 16)), training=True)
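
A quick way to convince yourself (a small self-contained check of my own, not from the docs) is to call a tiny functional model with and without training=True and watch the Dropout layer switch on and off:

# Toy functional model containing only a Dropout layer.
inputs = tf.keras.Input(shape=(16,))
outputs = tf.keras.layers.Dropout(0.5)(inputs)
toy = tf.keras.Model(inputs, outputs)

x = tf.ones((2, 16))
print(toy(x, training=False))  # inference: Dropout is a no-op, output is all ones
print(toy(x, training=True))   # training: ~half the entries zeroed, rest scaled by 1/(1-0.5)

# During toy.fit(...) / toy.evaluate(...), Keras sets the flag for you.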
Ben Usman
  • I landed here searching for an example like this. After that, I realized I could make a simple test myself: https://colab.research.google.com/gist/jjclavijo/f216cb335fdd206bf68238553f9658b0/scratchpad.ipynb. Maybe it is a good complement to the answer. – Javier JC Feb 28 '22 at 16:59