
Let's start at the beginning. So far I have created and trained small networks in TensorFlow myself. During training I save my model, which produces the following files in my directory:

model.ckpt.meta
model.ckpt.index
model.ckpt.data-00000-of-00001

Later, I load the model saved in `network_dir` to run some classifications and to extract the trainable variables of my model:

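import tensorflow as tf
# network_dir is the checkpoint prefix, i.e. the path ending in "model.ckpt"
saver = tf.train.import_meta_graph(network_dir + ".meta")
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="NETWORK")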
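For reference, a minimal, self-contained sketch of that save/reload round trip. It assumes the TF1 graph API through `tensorflow.compat.v1` (so it also runs on TF 2.x); the tiny network, its variables `w` and `b`, and the temporary directory are made up for illustration.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use the TF1 graph/session API

# Hypothetical location; the prefix plays the role of network_dir
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# 1) Build and save a tiny network whose variables live under the "NETWORK" scope.
with tf.Graph().as_default():
    with tf.variable_scope("NETWORK"):
        x = tf.placeholder(tf.float32, [None, 4], name="x")
        w = tf.get_variable("w", shape=[4, 2])
        b = tf.get_variable("b", shape=[2], initializer=tf.zeros_initializer())
        logits = tf.add(tf.matmul(x, w), b, name="logits")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Writes model.ckpt.meta, model.ckpt.index and model.ckpt.data-00000-of-00001
        saver.save(sess, ckpt_prefix)

# 2) In a fresh graph, rebuild the structure from the .meta file and restore the weights.
with tf.Graph().as_default():
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="NETWORK")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)
        names = sorted(v.name for v in variables)
        print(names)  # ['NETWORK/b:0', 'NETWORK/w:0']
```

The `.meta` file is what makes this work without any model code: it stores the graph structure, while `.index` and `.data-*` hold only the variable values.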

Now I want to work with larger pretrained models such as VGG16 or ResNet, and I want to load them with the same code I use for my own networks, as shown above.

On this site, I found many pretrained models:

https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models

I downloaded the VGG16 checkpoint and realized that it contains only the trained parameters.
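A checkpoint like this is essentially a mapping from variable names to values, with no layers or ops. The sketch below shows one way to inspect it with `tf.train.list_variables` and `tf.train.load_checkpoint`. To stay runnable on its own it first writes a small stand-in checkpoint (the `toy_net` scope and variable names are hypothetical); with the real download you would pass the path of the extracted checkpoint file instead. It assumes the TF1 API via `tensorflow.compat.v1`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical stand-in for a downloaded checkpoint: parameters only, no graph.
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")
with tf.Graph().as_default():
    with tf.variable_scope("toy_net/conv1"):
        tf.get_variable("weights", shape=[3, 3, 3, 8])
        tf.get_variable("biases", shape=[8])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # write_meta_graph=False: store the variable values but not the graph structure
        tf.train.Saver().save(sess, ckpt_path, write_meta_graph=False)

# List every stored variable with its shape (sorted by name)
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)

# Read a single value back as a NumPy array
reader = tf.train.load_checkpoint(ckpt_path)
bias = reader.get_tensor("toy_net/conv1/biases")
print(bias.shape)
```

Listing the names this way is also a quick check that the model code you build later creates variables with matching names.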

I would like to know how or where I can get the saved model or graph structure of these pretrained networks. For example, how do I use the VGG16 checkpoint without the model.ckpt.meta, model.ckpt.index and model.ckpt.data-00000-of-00001 files?

Gilfoyle

1 Answer


Next to each weights link there is a link to the code that defines the model; for VGG16, for instance, that is vgg.py. Create the model using that code, then restore the variables from the checkpoint:

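import tensorflow as tf
from nets import vgg  # vgg.py from tensorflow/models research/slim/nets (slim dir on the Python path)

slim = tf.contrib.slim

# Define your input somehow, e.g. with a placeholder for a batch of 224x224 RGB images
image = tf.placeholder(tf.float32, shape=[None, 224, 224, 3])
with slim.arg_scope(vgg.vgg_arg_scope()):
    # is_training=False disables dropout for inference
    logits, _ = vgg.vgg_16(image, is_training=False)
predictions = tf.argmax(logits, 1)
variables_to_restore = slim.get_variables_to_restore()

saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
    # Loads the pretrained weights; no variable initialization is needed
    saver.restore(sess, "/path/to/model.ckpt")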

The code in vgg.py creates all the variables for you, and the TF-Slim helper `slim.get_variables_to_restore()` gives you the list of them. From there, just follow the usual procedure: pass the list to a `tf.train.Saver` and restore the checkpoint. There was a similar question on this.
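The same procedure can be sketched without the slim models repo. Here `build_model` is a hypothetical stand-in for `vgg.vgg_16`, the "pretrained" checkpoint is written with `write_meta_graph=False` so it holds parameters only, and `tf.global_variables()` stands in for `slim.get_variables_to_restore()`. What it illustrates: restoring matches variables by name, so the graph rebuilt from code must create variables with the same names as those stored in the checkpoint. It assumes the TF1 API via `tensorflow.compat.v1`.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_model(inputs):
    """Hypothetical stand-in for vgg.vgg_16: code that (re)creates the graph."""
    with tf.variable_scope("toy_net"):
        with tf.variable_scope("fc"):
            w = tf.get_variable("weights", shape=[4, 3])
            b = tf.get_variable("biases", shape=[3], initializer=tf.zeros_initializer())
        return tf.matmul(inputs, w) + b


ckpt_path = os.path.join(tempfile.mkdtemp(), "pretrained.ckpt")

# Pretend this is the downloaded checkpoint: values only, no .meta file.
with tf.Graph().as_default():
    build_model(tf.placeholder(tf.float32, [None, 4]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved = sess.run(tf.global_variables())
        tf.train.Saver().save(sess, ckpt_path, write_meta_graph=False)

# Later: recreate the graph from code, then load the weights into it by name.
with tf.Graph().as_default():
    image = tf.placeholder(tf.float32, [None, 4])
    logits = build_model(image)
    predictions = tf.argmax(logits, 1)
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)  # replaces initialization
        restored = sess.run(tf.global_variables())
        preds = sess.run(predictions, feed_dict={image: np.ones((2, 4), np.float32)})
```

If the names in your graph differ from those in the checkpoint (for example, because the model was built under another scope), `tf.train.Saver` also accepts a dict mapping checkpoint names to variables.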

Dmytro Prylipko
  • How do I do that? Can you please be more specific? I am a beginner. – Gilfoyle Feb 08 '19 at 16:01
  • Thank you very much. I imported the VGG16 code with `import vgg`. However, I get the following error: `NameError: name 'slim' is not defined`. How can I solve this problem to get all the variables of the network? – Gilfoyle Feb 13 '19 at 09:54
  • `slim = tf.contrib.slim` – Dmytro Prylipko Feb 13 '19 at 10:15
  • Sorry for asking such a dumb question. – Gilfoyle Feb 13 '19 at 10:15
  • Can you please show how I can extract all activations of VGG16? Normally I get them with `activations = tf.get_collection('Activations')`, but that does not seem to work here. What is the right way to get the activations? – Gilfoyle Feb 13 '19 at 10:35