I want to wrap the attention-OCR model with OpenCV DNN to reduce inference time. I am using the TF code from the official TF models repo.
To wrap the TF model with OpenCV DNN, I am referring to this code. cv2.dnn.readNetFromTensorflow() expects a 'frozen graph' (.pb) and a 'graph structure' (.pbtxt) to read a TF model.
I use this code snippet to import the graph structure from the meta checkpoint file and save it to a .pbtxt file:
```python
import tensorflow as tf

# load graph from meta file
tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("attention_ocr_2017_08_09/model_demo_inference.ckpt.meta")
# restore the variables into the session's graph
sess = tf.Session()
imported_meta.restore(sess, 'attention_ocr_2017_08_09/model_demo_inference.ckpt')
# write graph structure to a pbtxt file
tf.train.write_graph(sess.graph_def, './', 'train_attention.pbtxt', as_text=True)
```
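Because the .pbtxt is written from the full meta graph, it also carries training and checkpoint nodes (for example the save/* nodes referenced below). A quick way to see which op types ended up in the file is a rough text scan. The sketch below is only a regex heuristic over `op: "..."` lines, not a protobuf parser, and `count_op_types` is a hypothetical helper name.

```python
import re
from collections import Counter

# Matches NodeDef lines such as `op: "Conv2D"` in a text-format GraphDef.
# This is a rough regex scan, not a real protobuf parser.
OP_PATTERN = re.compile(r'^\s*op:\s*"([^"]+)"', re.MULTILINE)


def count_op_types(pbtxt_text):
    """Return a Counter mapping op type -> number of matching `op:` lines."""
    return Counter(OP_PATTERN.findall(pbtxt_text))
```

For example, `count_op_types(open('train_attention.pbtxt').read()).most_common(20)` would list the 20 most frequent op types in the written graph.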
To freeze the graph, I use the following code:
```python
from tensorflow.python.tools import freeze_graph

freeze_graph.freeze_graph('train_attention.pbtxt', '', False,
                          'attention_ocr_2017_08_09/model_demo_inference.ckpt',
                          'AttentionOcr_v1_1/Softmax',
                          'save/restore_all', 'save/Const:0', 'frozen_model.pb', True, "")
```
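The call above passes ten positional arguments, which makes it easy to misplace one. As a sketch, the helper below pairs each value with the parameter it binds to, assuming the TF 1.x signature of `freeze_graph.freeze_graph` in tensorflow/python/tools/freeze_graph.py (first ten parameters, in the order listed in `FREEZE_GRAPH_PARAMS`). `as_keyword_args` is a hypothetical helper, not part of TensorFlow, and the sketch does not import TensorFlow.

```python
# Assumed order of the first ten parameters of freeze_graph.freeze_graph() in TF 1.x.
FREEZE_GRAPH_PARAMS = (
    "input_graph",
    "input_saver",
    "input_binary",
    "input_checkpoint",
    "output_node_names",
    "restore_op_name",
    "filename_tensor_name",
    "output_graph",
    "clear_devices",
    "initializer_nodes",
)


def as_keyword_args(*positional):
    """Pair positional freeze_graph arguments with their parameter names."""
    if len(positional) > len(FREEZE_GRAPH_PARAMS):
        raise ValueError("too many positional arguments")
    return dict(zip(FREEZE_GRAPH_PARAMS, positional))


# The call from the question, spelled out by parameter name.
kwargs = as_keyword_args(
    "train_attention.pbtxt", "", False,
    "attention_ocr_2017_08_09/model_demo_inference.ckpt",
    "AttentionOcr_v1_1/Softmax",
    "save/restore_all", "save/Const:0", "frozen_model.pb", True, "",
)
for name, value in kwargs.items():
    print(f"{name} = {value!r}")
# With TensorFlow 1.x installed, this dict could be passed as
# freeze_graph.freeze_graph(**kwargs).
```

Spelling the call out this way also makes it visible that `input_binary=False` matches the text-format .pbtxt input and that `output_node_names` is a node name without a `:0` suffix.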
The final code passes the .pb and .pbtxt files to cv2.dnn.readNetFromTensorflow():
```python
# Wrap TF model in OpenCV DNN
import cv2
FROZEN_GRAPH = "frozen_model.pb"
PB_TXT = "train_attention.pbtxt"
img = cv2.imread('testdata/fsns_train_00.png')
blob = cv2.dnn.blobFromImage(img,1)
net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
out = net.forward()
out
```
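For reference, `cv2.dnn.blobFromImage(img, 1)` turns the H x W x C image from `cv2.imread` into a 4-D N x C x H x W blob, computing `(pixel - mean) * scalefactor` with scalefactor 1 here. Once the network loads, the blob would also need to be passed with `net.setInput(blob)` before `net.forward()`. The pure-Python sketch below only illustrates the layout change and scaling for one image; it deliberately ignores resizing, `swapRB` and cropping, and `hwc_to_nchw` is a hypothetical name, not an OpenCV function.

```python
def hwc_to_nchw(image, scalefactor=1.0, mean=(0.0, 0.0, 0.0)):
    """Mimic the blob layout of cv2.dnn.blobFromImage for a single image.

    `image` is a nested list indexed [row][col][channel] (H x W x C).
    The result is indexed [0][channel][row][col] (1 x C x H x W), with
    (pixel - mean[channel]) * scalefactor applied to every value.
    Resizing, swapRB and cropping are left out of this sketch.
    """
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [[
        [[(image[y][x][c] - mean[c]) * scalefactor for x in range(width)]
         for y in range(height)]
        for c in range(channels)
    ]]
```

Keeping channels as the second axis is what OpenCV DNN expects as network input, so a real blob for this model would have shape (1, 3, H, W).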
The error encountered is:
```
---------------------------------------------------------------------------
error Traceback (most recent call last)
<ipython-input-128-09e46e8b88ed> in <module>
9 blob = cv2.dnn.blobFromImage(img,1)
10
---> 11 net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
12 out = net.forward()
13 out
error: OpenCV(4.0.0) /Users/travis/build/skvark/opencv-python/opencv/modules/dnn/src/tensorflow/tf_io.cpp:54: error: (-2:Unspecified error)
FAILED: ReadProtoFromTextFile(param_file, param).
Failed to parse GraphDef file: train_attention.pbtxt in function 'ReadTFNetParamsFromTextFileOrDie'
```
Note: I set the output node name manually by looking at the list of tensors in the graph, generated with:
```python
# get names of all tensors
def get_names(graph=sess.graph):
    return [t.name for op in graph.get_operations() for t in op.values()]

l1 = get_names()
for ele in l1:
    print(ele)
```
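`get_names()` returns tensor names of the form `op_name:output_index` (for example `AttentionOcr_v1_1/Softmax:0`), whereas `freeze_graph`'s `output_node_names` expects node names without the `:0` suffix (op names are also available directly as `op.name`). The sketch below strips the output index and filters by a name suffix; both helpers are hypothetical, and filtering on "Softmax" is only a heuristic for spotting candidate output nodes, not a rule.

```python
def tensor_to_node_name(tensor_name):
    """Strip a trailing output index, e.g. 'scope/Softmax:0' -> 'scope/Softmax'."""
    node, sep, index = tensor_name.rpartition(":")
    # Only strip a trailing ':<digits>'; leave other names untouched.
    if sep and index.isdigit():
        return node
    return tensor_name


def candidate_output_nodes(tensor_names, suffix="Softmax"):
    """Return unique node names whose last path component ends with `suffix`."""
    found = []
    for name in tensor_names:
        node = tensor_to_node_name(name)
        if node.split("/")[-1].endswith(suffix) and node not in found:
            found.append(node)
    return found
```

For example, `candidate_output_nodes(get_names())` would narrow the printed list down to Softmax nodes such as the one used in the `freeze_graph` call.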
I would greatly appreciate any help provided by the SO community.