
If the shapes of all tensors in a TensorFlow graph are well defined, I want to be able to extract the shape of every tensor. The graph is saved as a protobuf (GraphDef) file. Here is an example of how one is created:

import tensorflow as tf

with tf.Graph().as_default() as graph:
    a = tf.compat.v1.placeholder(tf.int32, shape=(4, 3), name='a')
    b = tf.compat.v1.placeholder(tf.int32, shape=(4, 3), name='b')
    c = tf.add(a, b, name='c')

with tf.compat.v1.Session(graph=graph) as sess:
    graph_def = sess.graph.as_graph_def()
    with open('simple_graph.pb', 'wb') as f:
        f.write(graph_def.SerializeToString())

Now I want to extract the shape information of an arbitrary graph with well-defined shapes from C++. I tried:

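#include <iostream>

#include "tensorflow/core/framework/graph.pb.h"

void GetTensorShapes(const tensorflow::GraphDef &graph_def) {
    for (const auto &node : graph_def.node()) {
        // Only some ops (e.g. Placeholder) carry a "shape" attr; at() would
        // throw std::out_of_range for the others, so look it up with find().
        const auto it = node.attr().find("shape");
        if (it == node.attr().end()) continue;
        const tensorflow::TensorShapeProto &shape = it->second.shape();
        std::cout << "Node name: " << node.name() << ", shape: ";
        for (const auto &dim : shape.dim()) {
            std::cout << dim.size() << " ";
        }
        std::cout << std::endl;
    }
}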

But shape() returns an empty shape with no dimensions. The equivalent in Python would be:

import tensorflow as tf

def get_shapes(path):
    graph_def = tf.compat.v1.GraphDef()
    with open(path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # Importing the GraphDef runs shape inference on every op.
        tf.import_graph_def(graph_def, name='')

    # Static shapes are available without running a session.
    shapes = []
    for op in graph.get_operations():
        for output in op.outputs:
            shapes.append(output.shape.as_list())
    return shapes
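A sketch of one direction that may help, assuming the graph can be re-exported from Python: a plain GraphDef only stores shapes that happen to be op attributes (such as a placeholder's shape), whereas Graph.as_graph_def(add_shapes=True) additionally records each node's statically inferred output shapes in an "_output_shapes" attribute. The helper name output_shapes_from_graph_def is hypothetical, and the sketch assumes every rank is known.

```python
import tensorflow as tf

# Rebuild the toy graph from the question.
with tf.Graph().as_default() as graph:
    a = tf.compat.v1.placeholder(tf.int32, shape=(4, 3), name='a')
    b = tf.compat.v1.placeholder(tf.int32, shape=(4, 3), name='b')
    c = tf.add(a, b, name='c')

# add_shapes=True attaches an "_output_shapes" list attr to every NodeDef,
# holding the statically inferred shape of each of the node's outputs.
serialized = graph.as_graph_def(add_shapes=True).SerializeToString()


def output_shapes_from_graph_def(data):
    """Hypothetical helper: parse a serialized GraphDef and return
    {node name: [shape, ...]} from the "_output_shapes" attr.

    Nodes without that attr are skipped; unknown dimensions appear as -1.
    """
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(data)
    shapes = {}
    for node in graph_def.node:
        if '_output_shapes' not in node.attr:
            continue
        shapes[node.name] = [
            [dim.size for dim in shape.dim]
            for shape in node.attr['_output_shapes'].list.shape
        ]
    return shapes


print(output_shapes_from_graph_def(serialized))
```

If the .pb file is written this way, the C++ side should be able to read the same data with node.attr().find("_output_shapes") and then it->second.list().shape(), which yields one TensorShapeProto per output instead of relying on the "shape" attr.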