In my neural network, I create some tf.Variable objects as follows:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

How would I save the variables in weights and biases after a specific number of iterations without saving other variables?

mrry
lhao0301
  • There are quite a few different ways of doing it. You can use a tensorflow saver or use your favorite format like h5 or npy. – jeandut Sep 12 '16 at 12:16

1 Answer


The standard way to save variables in TensorFlow is to use a tf.train.Saver object. By default, it saves every variable in your program (i.e., the result of tf.global_variables(), which was called tf.all_variables() in older releases), but you can save a subset of them by passing the optional var_list argument to the tf.train.Saver constructor:

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)

# ...
# You need a TensorFlow Session to save variables.
sess = tf.Session()
# ...

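# ...then call the following methods as appropriate. save() requires a
# checkpoint path as its second argument; the file names here are placeholders.
weights_saver.save(sess, './weights.ckpt')  # Save the current value of the weights.
biases_saver.save(sess, './biases.ckpt')    # Save the current value of the biases.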
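To address the "after a specific number of iterations" part of the question, here is a minimal sketch of saving only the weights and biases every fixed number of steps. It is one possible arrangement, not the only one. The names SAVE_EVERY, NUM_STEPS, the temporary directory, the 'weights_biases' file prefix, and the extra variable `other` are all assumptions made for illustration, and the tensorflow.compat.v1 import is assumed so that the TF1-style API also runs on TensorFlow 2.x installations.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

SAVE_EVERY = 100               # hypothetical checkpoint interval
NUM_STEPS = 300                # hypothetical number of training iterations
ckpt_dir = tempfile.mkdtemp()  # hypothetical output directory

graph = tf.Graph()
with graph.as_default():
    weights = {
        'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
        'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    }
    biases = {
        'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
        'bc1_1': tf.Variable(tf.constant(0.0, shape=[64])),
    }
    # A variable that should NOT end up in the checkpoint.
    other = tf.Variable(tf.zeros([10]), name='not_saved')

    # One saver covering both dictionaries and nothing else.
    to_save = dict(weights)
    to_save.update(biases)
    saver = tf.train.Saver(var_list=to_save)
    init = tf.global_variables_initializer()

saved_paths = []
with tf.Session(graph=graph) as sess:
    sess.run(init)
    for step in range(1, NUM_STEPS + 1):
        # A real training loop would run its training op here.
        if step % SAVE_EVERY == 0:
            # global_step appends the step number to the checkpoint name.
            path = saver.save(
                sess,
                os.path.join(ckpt_dir, 'weights_biases'),
                global_step=step,
            )
            saved_paths.append(path)

# Names stored in the most recent checkpoint.
saved_names = sorted(name for name, _ in tf.train.list_variables(saved_paths[-1]))
```

Note that a Saver keeps only the five most recent checkpoints by default, so a longer run would need the max_to_keep constructor argument raised if older checkpoints must survive.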

Note that if you pass a dictionary to the tf.train.Saver constructor (such as the weights and/or biases dictionaries from your question), TensorFlow will use the dictionary key (e.g. 'wc1_0') as the name for the corresponding variable in any checkpoint files it creates or consumes.

By default, or if you pass a list of tf.Variable objects to the constructor, TensorFlow will use the tf.Variable.name property instead.

Passing a dictionary gives you the ability to share checkpoints between models that give different Variable.name properties to each variable. This detail is only important if you want to use the created checkpoints with another model.
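The effect of the dictionary keys can be sketched as follows. Two separate graphs give their bias variable different Variable.name values, yet both savers map it to the same checkpoint key 'bc1_0', so the value saved from the first can be restored into the second. The variable names 'model_a_bias' and 'model_b_bias', the value 1.5, and the temporary checkpoint path are hypothetical, and the tensorflow.compat.v1 import is again an assumption so the example also runs on TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use TF1-style graphs and sessions

ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'biases')  # hypothetical path

# Model A: save its bias under the checkpoint key 'bc1_0'.
graph_a = tf.Graph()
with graph_a.as_default():
    bias_a = tf.Variable(tf.constant(1.5, shape=[64]), name='model_a_bias')
    saver_a = tf.train.Saver(var_list={'bc1_0': bias_a})
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved = saver_a.save(sess, ckpt_prefix)

# The checkpoint stores the dictionary key, not 'model_a_bias'.
checkpoint_names = [name for name, _ in tf.train.list_variables(saved)]

# Model B: a differently named variable, mapped to the same key.
graph_b = tf.Graph()
with graph_b.as_default():
    bias_b = tf.Variable(tf.zeros([64]), name='model_b_bias')
    saver_b = tf.train.Saver(var_list={'bc1_0': bias_b})
    with tf.Session() as sess:
        # restore() assigns the saved value, so no initializer is needed.
        saver_b.restore(sess, saved)
        restored = sess.run(bias_b)
```

Had both savers been given plain lists instead of dictionaries, the checkpoint would have been keyed by 'model_a_bias', and the restore in the second graph would have failed to find 'model_b_bias'.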

mrry