I define a deep CNN with TensorFlow that includes a batch-normalization op, so my code may look like this:
```python
def network(input):
    ...
    input = tf.layers.batch_normalization(input, ...)
    ...
```
Assume the network has been trained and the checkpoint file has been saved. Now I would like to use this model for inference. Normally, I can call `network(input)` again, this time passing `training=False` to `tf.layers.batch_normalization()`, and then restore the weights from the checkpoint file.
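The rebuild-and-restore route can be sketched as below. This is a minimal sketch assuming the `tf.compat.v1` API; the tiny `network(inputs, training)` (one dense layer plus batch normalization) and the layer names `dense` and `bn` are hypothetical stand-ins for the real model. Because both graphs create variables with the same names, `tf.train.Saver` can restore the saved values into the graph built with `training=False`, where batch normalization uses the stored moving mean and variance instead of batch statistics.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def network(inputs, training):
    # Hypothetical stand-in for the real model: a dense layer followed by
    # batch normalization whose mode is controlled by `training`.
    x = tf.layers.dense(inputs, 4, name="dense")
    return tf.layers.batch_normalization(x, training=training, name="bn")


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Training-time graph (training=True). Here we only initialize and save.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 3])
    y = network(x, training=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)

# Inference graph: same Python code with training=False, then restore.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 3])
    y = network(x, training=False)
    with tf.Session() as sess:
        tf.train.Saver().restore(sess, ckpt_path)
        out = sess.run(y, feed_dict={x: np.ones((2, 3), np.float32)})
        # Restored dense weights, fetched only to sanity-check the BN output.
        kernel, bias = sess.run(["dense/kernel:0", "dense/bias:0"])
```

The catch is the one raised in the question: this only works as long as the Python code still produces the same variable names and shapes as the checkpoint.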
However, I would prefer to rebuild my network with `tf.train.import_meta_graph`, since the code in `network(input)` may change over time.
But how can I put the batch-normalization op into inference mode in that case? Since I no longer call `tf.layers.batch_normalization()` myself, I am not sure how to work this out.
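One commonly used pattern is to make the mode flag part of the graph itself: build the model once with `training` set to a named boolean tensor from `tf.placeholder_with_default(False, ...)`, save it, and after `tf.train.import_meta_graph` look that tensor up by name and feed it. This is a sketch assuming the graph can be re-exported this way (a graph saved with a hard-coded Python `training=True` cannot be switched afterwards) and assuming the `tf.compat.v1` API; the tensor names `input`, `is_training`, `output` and the layer name `bn` are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build and save: `training` is a named graph tensor, not a Python bool.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 3], name="input")
    # Defaults to inference mode; feed True while training.
    is_training = tf.placeholder_with_default(False, shape=(), name="is_training")
    y = tf.layers.batch_normalization(x, training=is_training, name="bn")
    y = tf.identity(y, name="output")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Saver.save() also writes model.ckpt.meta by default.
        tf.train.Saver().save(sess, ckpt_path)

# Inference: rebuild from the .meta file only, without the Python model code.
with tf.Graph().as_default() as g:
    saver = tf.train.import_meta_graph(ckpt_path + ".meta")
    x = g.get_tensor_by_name("input:0")
    is_training = g.get_tensor_by_name("is_training:0")
    y = g.get_tensor_by_name("output:0")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)
        batch = np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]], np.float32)
        # Not feeding is_training uses its default (False): moving statistics.
        infer_out = sess.run(y, {x: batch})
        # Feeding True switches to batch statistics.
        train_out = sess.run(y, {x: batch, is_training: True})
```

Leaving `is_training` unfed gives inference behavior. During actual training, the ops collected in `tf.GraphKeys.UPDATE_OPS` still need to be run for the moving averages to be updated.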