1

In TensorFlow 1.12 I'm trying to change a variable's value, but even after running the assignment op, the updated value is not what I get back when I save and then restore the model.

import tensorflow as tf

# Reproduction of the problem: build a graph, change a variable from 2 to 3,
# checkpoint it, then restore it into a fresh graph and read it back.
tf.reset_default_graph()

# Phase 1: create the variable, assign 3 to it, and save a checkpoint.
g = tf.Graph()
with g.as_default(), tf.Session(graph=g) as sess:
    w = tf.Variable(2,name = "VARIABLE")
    sess.run(tf.global_variables_initializer())
    y = sess.run(w)
    print('initial value', y)
    # Running the assign op stores 3 in the variable for this session.
    ww = tf.assign(w, 3)
    y = sess.run(ww)
    print('changed value', y)

    # Saves the session's variable values under the './test_1' prefix
    # (the graph is read back below from './test_1.meta').
    saver = tf.train.Saver()
    save_path = saver.save(sess, './test_1')


# Phase 2: import the saved graph into a new graph and restore its variables.
model_path = "./test_1.meta"
g = tf.Graph()
with g.as_default():
    # NOTE(review): this session is never closed; prefer
    # `with tf.Session(graph=g) as sess:`.
    sess = tf.Session(graph=g)
    saver = tf.train.import_meta_graph(model_path, clear_devices=True)
    # The checkpoint prefix is the meta-graph path without its '.meta' suffix.
    saver.restore(sess, model_path.replace('.meta', ''))
    # BUG: this re-runs every variable's initializer *after* restore(),
    # overwriting the restored value 3 with the initial value 2. That is why
    # 'loaded value' prints 2. Removing this line fixes the problem.
    sess.run(tf.global_variables_initializer())
    gb = tf.global_variables()
    print(gb)

    w = gb[0]  # the only global variable in this graph ('VARIABLE:0')
    y = sess.run(w)
    print('loaded value', y)

Output:

('initial value', 2)
('changed value', 3)
INFO:tensorflow:Restoring parameters from ./test_1
[<tf.Variable 'VARIABLE:0' shape=() dtype=int32_ref>]
('loaded value', 2)

The loaded value is 2, not the expected value 3.

Innat
  • 16,113
  • 6
  • 53
  • 101
Ryan Halabi
  • 95
  • 4
  • 17

1 Answer

1

I executed this code in TensorFlow 1.x and reproduced the same problem. The cause is calling `global_variables_initializer` after restoring the model. `saver.restore()` has already loaded the saved value (y=3) into the variable, and running `sess.run(tf.compat.v1.global_variables_initializer())` afterwards overwrites it with the variable's initial value (y=2). If you don't run the initializer after restoring (comment that line out), you get the updated value, y=3.

Below is the reproduced code and its output:

import tensorflow as tf

tf.compat.v1.reset_default_graph()

# Phase 1: create the variable, change its value to 3, and save a checkpoint.
g = tf.Graph()
with g.as_default(), tf.compat.v1.Session(graph=g) as sess:
    w = tf.Variable(2, name="VARIABLE")
    sess.run(tf.compat.v1.global_variables_initializer())
    y = sess.run(w)
    print('initial value', y)

    ww = tf.compat.v1.assign(w, 3)
    y = sess.run(ww)
    print('changed value', y)

    saver = tf.compat.v1.train.Saver()
    # saver.save() returns the checkpoint path prefix ('./test_1'); the
    # meta graph is written alongside it as '<prefix>.meta'.
    save_path = saver.save(sess, './test_1')

# Use the prefix returned by save() instead of string-replacing '.meta' out
# of the meta path, which breaks if '.meta' appears elsewhere in the path.
checkpoint_prefix = save_path
model_path = checkpoint_prefix + ".meta"

# Phase 2: import the saved graph into a fresh graph and restore its values.
# The `with` block ensures the session is closed when we are done.
g = tf.Graph()
with g.as_default(), tf.compat.v1.Session(graph=g) as sess:
    saver = tf.compat.v1.train.import_meta_graph(model_path, clear_devices=True)
    saver.restore(sess, checkpoint_prefix)
    # Do NOT run sess.run(tf.compat.v1.global_variables_initializer()) here:
    # restore() has already assigned every saved variable, and running the
    # initializer afterwards would reset w to its initial value (2).
    gb = tf.compat.v1.global_variables()
    print(gb)

    w = gb[0]  # the only global variable in this graph ('VARIABLE:0')
    y = sess.run(w)
    print('loaded value', y)

Output:

initial value 2
changed value 3
INFO:tensorflow:Restoring parameters from ./test_1
[<tf.Variable 'VARIABLE:0' shape=() dtype=int32>]
loaded value 3

For more information, please click here.