
I am completely lost on the TensorFlow saver method.

I'm trying to follow the basic TensorFlow deep neural network tutorial. I want to train the network for a few iterations, then load the saved model in another session.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

with tf.Session() as sess:
    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])

    #Define the Network
    #(This part is copied from the tutorial and omitted here for brevity;
    # it is where y_conv, cross_entropy and global_step come from.)
    #See here: https://www.tensorflow.org/versions/r0.12/tutorials/mnist/pros/

Skipping ahead to training.

    #Train the Network
    train_step = tf.train.AdamOptimizer(1e-4).minimize(
                     cross_entropy, global_step=global_step)
    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()

    # Initialize only after every variable (including Adam's slot variables) exists
    sess.run(tf.global_variables_initializer())

    for i in range(101):
        batch = mnist.train.next_batch(50)
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1]})
            print('Step %d, training accuracy %g' % (i, train_accuracy))
        train_step.run(feed_dict={x: batch[0], y_: batch[1]})
        if i % 100 == 0:
            print('Test accuracy %g' % accuracy.eval(feed_dict={
                x: mnist.test.images, y_: mnist.test.labels}))

    # Save once, after training has finished
    saver.save(sess, './mnist_model')

The console prints out:

Step 0, training accuracy 0.16
Test accuracy 0.0719
Step 100, training accuracy 0.88
Test accuracy 0.8734

Next, I load the model in a new session:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('mnist_model.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    sess.run(tf.global_variables_initializer())

Then, still inside that session, I re-test to see whether the model loaded:

    print('Test accuracy %g' % accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels}))

The console prints out:

Test accuracy 0.1151

It looks as if the model isn't saving any of its trained weights. What am I doing wrong?

Carbon Rod
  • You shouldn't run `sess.run(tf.global_variables_initializer())` after restoring weights. This will reset all your weights – martianwars Jun 05 '17 at 23:37

1 Answer


When you save a model with tf.train.Saver(), by default all global variables are written to the checkpoint files, whereas local variables are not. You can have a look at this answer to understand the difference.
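A minimal sketch to check this yourself, assuming TensorFlow 1.15 or 2.x (it goes through the tf.compat.v1 module; on TF 1.x you could use import tensorflow as tf directly). The names global_w and local_counter are made up for illustration. It saves one global and one local variable, then lists what actually ended up in the checkpoint:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions

with tf.Graph().as_default():
    # Regular (global) variable: included by tf.train.Saver() by default
    w = tf.Variable([1.0, 2.0], name='global_w')
    # Local variable: placed only in the LOCAL_VARIABLES collection
    counter = tf.Variable(0, name='local_counter', trainable=False,
                          collections=[tf.GraphKeys.LOCAL_VARIABLES])
    saver = tf.train.Saver()  # no var_list given

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        ckpt_dir = tempfile.mkdtemp()
        ckpt_path = saver.save(sess, os.path.join(ckpt_dir, 'model'))

# Inspect the checkpoint on disk: (name, shape) pairs
saved_names = [name for name, _ in tf.train.list_variables(ckpt_path)]
print(saved_names)
```

In this sketch the checkpoint should contain global_w but not local_counter.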

The error in your restoration code is calling tf.global_variables_initializer() after saver.restore(): the initializer overwrites the restored values with the variables' initial values. The saver.restore docs say:

The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.
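A small sketch of why the order matters, under the same tf.compat.v1 assumption (the variable w and its values are hypothetical). It saves a "trained" variable, restores it into a fresh graph without running any initializer, and then runs tf.global_variables_initializer() the way the question does:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'model')

# 1) "Train": w starts at zeros, gets new values, and is saved
with tf.Graph().as_default():
    w = tf.Variable(tf.zeros([3]), name='w')
    assign_trained = w.assign([0.5, -1.0, 2.0])  # stands in for training
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(assign_trained)
        saver.save(sess, ckpt_prefix)

# 2) Restore into a fresh graph, then (wrongly) run the global initializer
with tf.Graph().as_default():
    saver = tf.train.import_meta_graph(ckpt_prefix + '.meta')
    w = tf.global_variables()[0]  # the only global variable in this graph
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)  # no initializer needed before this
        restored = sess.run(w)
        sess.run(tf.global_variables_initializer())  # re-runs w's initializer
        after_init = sess.run(w)

print(restored)    # the trained values
print(after_init)  # back to the initial zeros
```

The second read comes from w's own initializer op, which is what happened to the MNIST weights in the question and why the test accuracy fell to roughly chance level.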

Hence, remove the line:

sess.run(tf.global_variables_initializer())

If your graph contains local variables, replace it with:

sess.run(tf.local_variables_initializer())
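A sketch of the corrected restore pattern, again assuming tf.compat.v1. The tiny graph (a placeholder x, a variable w and an output y) is a hypothetical stand-in for the MNIST model. Because the model is restored into a fresh graph, the Python objects from the training script cannot be reused, so the tensors are looked up by name instead, which is one common way to do it:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()

# Save side: hypothetical stand-in for the trained model
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None], name='x')
    w = tf.Variable(3.0, name='w')
    y = tf.multiply(x, w, name='y')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(w.assign(5.0))  # stands in for training
        saver.save(sess, os.path.join(ckpt_dir, 'model'))

# Restore side: fresh graph and session, no global initializer
with tf.Graph().as_default() as g:
    saver = tf.train.import_meta_graph(os.path.join(ckpt_dir, 'model.meta'))
    x = g.get_tensor_by_name('x:0')
    y = g.get_tensor_by_name('y:0')
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        # Only matters if the graph has local variables; otherwise a no-op
        sess.run(tf.local_variables_initializer())
        result = sess.run(y, feed_dict={x: [1.0, 2.0]})

print(result)  # uses the restored w, not its initial value
```

For the question's model you would look up x, y_ and accuracy the same way; their graph names depend on how they were created, and passing name='...' when defining them makes the lookup easier.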
martianwars
    Thanks, this certainly seems to have solved my issue! If the docs state that `saver.restore()` is itself an initialization, does `sess.run(tf.local_variables_initializer())` serve any purpose? This also seems to suggest that tutorials such as [A quick complete tutorial to save and restore Tensorflow models](http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/) show incorrect usage, doesn't it? – Carbon Rod Jun 06 '17 at 12:58
  • You should check [`tf.local_variables()`](https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/local_variables). Running `tf.local_variables_initializer()` is only needed if this list is non-empty – martianwars Jun 06 '17 at 14:46
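To illustrate when tf.local_variables() is non-empty: metric ops such as tf.metrics.accuracy keep their running totals in local variables, which the Saver does not write to the checkpoint. A sketch under the same tf.compat.v1 assumption, with made-up labels and predictions:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.Graph().as_default():
    labels = tf.constant([1, 0, 1, 1])
    predictions = tf.constant([1, 0, 0, 1])
    # tf.metrics.accuracy creates local variables for its running totals
    acc, update_op = tf.metrics.accuracy(labels, predictions)
    local_names = [v.op.name for v in tf.local_variables()]

    with tf.Session() as sess:
        if tf.local_variables():  # needed only when this list is non-empty
            sess.run(tf.local_variables_initializer())
        sess.run(update_op)
        value = sess.run(acc)

print(local_names)
print(value)  # 3 of the 4 predictions match
```

Without the local initializer, running update_op here would fail with an uninitialized-variable error, which is the case the comment above is pointing at.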