When the shapes of all tensors in a TensorFlow graph are well defined, I want to extract the shape of every tensor. The graph is saved as a binary protobuf (`GraphDef`) file. Here is an example of how such a file is created:
# Build a small graph whose tensor shapes are all statically known and
# serialize it to a binary GraphDef file.
with tf.Graph().as_default() as graph:
    a = tf.compat.v1.placeholder(tf.int32, shape=(4, 3), name='a')
    b = tf.compat.v1.placeholder(tf.int32, shape=(4, 3), name='b')
    c = tf.add(a, b, name='c')

# add_shapes=True makes every node carry an "_output_shapes" attr holding the
# inferred shape of each of its outputs. Without it, the GraphDef holds only
# the attrs the ops themselves define (e.g. the placeholders' "shape"), so a
# node such as "c" records no shape at all for a reader of the .pb file.
# No Session is needed just to export the graph definition.
graph_def = graph.as_graph_def(add_shapes=True)
with open('simple_graph.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())
Now I want to extract the shape information from an arbitrary graph whose shapes are all well defined, using the TensorFlow C++ API. I tried:
void GetTensorShapes(const tensorflow::GraphDef &graph_def) {
for (const auto &node : graph_def.node()) {
const auto &shape_attr = node.attr().at("shape");
const tensorflow::TensorShapeProto &shape = shape_attr.shape();
std::cout << "Node name: " << node.name() << ", shape: ";
for (const auto &dim : shape.dim()) {
std::cout << dim.size() << " ";
}
std::cout << std::endl;
}
}
But `shape()` returns a shape with no dimensions (its `dim()` list is empty). The equivalent in Python would be:
def get_shapes(path):
    """Return the static shape of every output tensor in a saved GraphDef.

    The binary GraphDef at *path* is imported into a fresh graph, which runs
    TensorFlow's shape inference on it; no Session is needed for that.

    Args:
        path: Path to a binary (serialized) ``GraphDef`` protobuf file.

    Returns:
        A list with one entry per output tensor, ordered by operation and then
        by output index. Each entry is a list of ints (``None`` for a dimension
        whose size is unknown), or ``None`` if the tensor's rank is unknown.
    """
    # tf.GraphDef / tf.import_graph_def / tf.Session are TF1 top-level names;
    # the compat.v1 spellings (already used elsewhere here) also work in TF2.
    graph_def = tf.compat.v1.GraphDef()
    with open(path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def)

    shapes = []
    for op in graph.get_operations():
        for output in op.outputs:
            shape = output.shape
            # as_list() raises ValueError for an unknown rank, and it yields
            # None for unknown dimensions instead of crashing on int(None).
            shapes.append(None if shape.rank is None else shape.as_list())
    return shapes