3

I trained a model with char-rnn-tensorflow (https://github.com/sherjilozair/char-rnn-tensorflow). The model is saved as a checkpoint. Now I want to serve the model with TensorFlow Serving.

I googled lots of tutorials about this, and only found one that met my needs. When I changed my code to follow that tutorial, as shown below, it returned a "node_name is not in graph" error.

I got the names of all nodes in the graph with "[n.name for n in tf.get_default_graph().as_graph_def().node]", but there are more than 10000 of them — far too many for me to figure out which ones belong to my model.

So my question is: is there a better way to find which node names I used during training, or a better way to transform a checkpoint into a SavedModel for use with TensorFlow Serving?

Thanks!

import tensorflow as tf
from model import Model
import argparse
import os
from six.moves import cPickle
from model import Model
#  Build the signature_def_map.
# X: ry, pkeep: 1.0, Hin: rh, batchsize: 1
def main():
    """Parse command-line options and freeze the checkpointed char-rnn graph."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # (flag, default, help) triples — every option is a plain string.
    options = [
        ('--data_dir', 'data/obama',
         'data directory containing input.txt'),
        ('--output_node_names', 'node_name',
         'output node names'),
        ('--output_graph', 'output_graph',
         'output_graph'),
        ('--save_dir', 'save_train3',
         'directory to store checkpointed models'),
    ]
    for flag, default, help_text in options:
        parser.add_argument(flag, type=str, default=default, help=help_text)

    args = parser.parse_args()
    print(args)
    freeze_graph(args)

def freeze_graph(args):
    """Freeze a char-rnn checkpoint into a single self-contained GraphDef.

    Loads the pickled training config (for reference), rebuilds the graph
    from the checkpoint's .meta file, restores the trained weights,
    converts all variables to constants, and serializes the frozen graph
    to ``<output_graph>/model.pb``.

    Args:
        args: argparse.Namespace with ``save_dir``, ``output_node_names``
            (comma-separated), and ``output_graph`` attributes.
    """
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    print(saved_args)

    ckpt = tf.train.get_checkpoint_state(args.save_dir)
    print(ckpt.model_checkpoint_path)

    with tf.Session() as sess:
        # Rebuild the graph from the .meta file ONLY. Constructing
        # Model(saved_args, training=False) here as well would add a
        # second copy of every op to the default graph — that is why
        # the node list contained >10000 names.
        saver = tf.train.import_meta_graph(
            ckpt.model_checkpoint_path + '.meta', clear_devices=True)
        # Restore the trained weights. No global_variables_initializer()
        # run is needed: restore() assigns every variable from the
        # checkpoint, and initializing first would be a wasted (and
        # potentially misleading) step.
        saver.restore(sess, ckpt.model_checkpoint_path)

        print(len([n.name for n in tf.get_default_graph().as_graph_def().node]))

        # Bake the restored variable values into the graph as constants,
        # keeping only what is reachable from the requested output nodes.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # Source of the nodes
            args.output_node_names.split(",")  # Selects the useful subgraph
        )

        # Join with os.path.join: plain "+" produced e.g.
        # "output_graphmodel.pb" instead of a file inside the directory.
        with tf.gfile.GFile(os.path.join(args.output_graph, "model.pb"), "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
Tony Wang
  • 971
  • 4
  • 16
  • 33
  • To transform checkpoints into a model file, the method you followed is the best one I know. About node names: you can try viewing the model's graph in TensorBoard and figuring out the relevant node names from there. – bsguru Feb 25 '18 at 05:47
  • @banguru Thanks for your reply. http://2ac561d5.ngrok.io/#graphs&run=2018-02-25-05-56-38 It's a tensorboard of this model. Do you have any suggestion for this output node names? Seems that, Adam is the last layer but didn't have any outputs. So it should be gradients? and the output_node_names here should be "global_norm", "clip_by_global_norm", "Adam/update_embedding/Unique"? – Tony Wang Feb 25 '18 at 06:02

0 Answers0