
I trained a model and then created a .pb file by freezing it. How can I get the weights from the .pb file, or do I need to do more processing to get them?

@mrry, please guide me.

mrgloom
  • Unfortunately I'm not mrry, but freezing a model gets you a GraphDef; you can [parse a GraphDef in Python](https://www.tensorflow.org/extend/tool_developers/#graphdef), which will have the values of constants (including your frozen weights). – Allen Lavoie Sep 11 '17 at 17:34
  • OK, thank you so much. – Sep 13 '17 at 05:10

1 Answer


First, load the graph from the .pb file.

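import tensorflow as tf
from tensorflow.python.platform import gfile

# TensorFlow 1.x API; in TensorFlow 2.x use tf.compat.v1.Session and tf.compat.v1.GraphDef
GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'  # path to your .pb file
with tf.Session() as sess:
  print("load graph")
  with gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
    # Parse the serialized GraphDef protobuf stored in the frozen .pb file
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  # as_default() returns a context manager, so it must be used with `with`
  with sess.graph.as_default():
    tf.import_graph_def(graph_def, name='')
  # All nodes of the graph, including the Const nodes that hold the frozen weights
  graph_nodes = [n for n in graph_def.node]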

When you freeze a graph into a .pb file, its variables are converted to Const nodes, so the weights that were trainable variables are stored as Const values inside the .pb file. graph_nodes contains every node in the graph, but we are only interested in the nodes of type Const.

wts = [n for n in graph_nodes if n.op=='Const']
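Note that Const nodes are not only trained weights: a frozen graph also contains small bookkeeping constants such as reduction axes, epsilons, and reshape targets. The sketch below is one possible way to thin the list out, assuming that weights are non-scalar tensors. It reads each constant's shape from the `tensor_shape.dim` field of the `TensorProto` stored in `n.attr['value']`. The helper names and the `min_elements` threshold are hypothetical, and the heuristic is only approximate: small 1-D shape constants can still pass it, so you may also want to check `n.attr['dtype']`.

```python
import math

def const_dims(node):
    # Dimension sizes of the tensor stored in a Const NodeDef, e.g. [3, 3, 32, 64].
    # An empty list means the constant is a scalar.
    return [d.size for d in node.attr['value'].tensor.tensor_shape.dim]

def looks_like_weight(dims, min_elements=2):
    # Hypothetical heuristic, not a TensorFlow rule: keep non-scalar tensors
    # holding at least `min_elements` values, drop scalars and 1-element tensors.
    return len(dims) > 0 and math.prod(dims) >= min_elements

def weight_consts(graph_def, min_elements=2):
    # Const nodes of a parsed GraphDef that pass the heuristic above.
    return [n for n in graph_def.node
            if n.op == 'Const' and looks_like_weight(const_dims(n), min_elements)]
```

With the graph_def loaded earlier, `for n in weight_consts(graph_def): print(n.name, const_dims(n))` lists the remaining candidates together with their shapes.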

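Each element of wts is a NodeDef, which has several attributes such as name and op. The stored values can be extracted with tensor_util.MakeNdarray, which converts the TensorProto held in the node's value attribute into a NumPy array: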

from tensorflow.python.framework import tensor_util

for n in wts:
    print("Name of the node - %s" % n.name)
    print("Value - ")
    # Convert the TensorProto stored in the 'value' attr into a NumPy array
    print(tensor_util.MakeNdarray(n.attr['value'].tensor))
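Under the hood, a TensorProto stores its data either as packed raw bytes in `tensor_content` or in typed repeated fields such as `float_val`, and MakeNdarray turns either form into an array. The stdlib-only sketch below illustrates this for the float32 case only. It is a simplification, not TensorFlow's implementation: it assumes little-endian bytes, ignores every other dtype, and returns a flat row-major list instead of reshaping. `decode_float32_tensor` is a hypothetical name.

```python
import struct

def decode_float32_tensor(dims, tensor_content=b'', float_val=()):
    # Number of elements implied by the shape (1 for a scalar, i.e. dims == []).
    count = 1
    for d in dims:
        count *= d
    if tensor_content:
        # Packed raw bytes: 4 bytes per float32 value, assumed little-endian.
        return list(struct.unpack('<%df' % count, tensor_content))
    if len(float_val) == 1:
        # A single stored value is repeated to fill the whole shape.
        return [float_val[0]] * count
    return list(float_val)
```

In real code, just call tensor_util.MakeNdarray (also exposed publicly as tf.make_ndarray). The sketch only shows why the same call works whether the weights were serialized as bytes or as a repeated field.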

Hope this solves your concern.

Krist