I use the randomforest estimator, implemented in tensorflow, to predict if a text is english or not. I saved my model (A dataset with 2k samples and 2 class labels 0/1 (Not English/English)) using the following code (train_input_fn function return features and class labels):
model_path='test/'
TensorForestEstimator(params, model_dir='model/')
estimator.fit(input_fn=train_input_fn, max_steps=1)
After running the above code, the graph.pbtxt and checkpoints are saved in the model folder. Now I want to use it on Android. I have 2 problems:
As the first step, I need to freeze the graph and checkpoints to a .pb file to use it on Android. I tried freeze_graph (I used the code here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). When I call the freeze_graph in my mode, I get the following error and the code cannot create the final .pb graph:
File "/Users/XXXXXXX/freeze_graph.py", line 105, in freeze_graph _ = tf.import_graph_def(input_graph_def, name="") File "/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 258, in import_graph_def op_def = op_dict[node.op] KeyError: u'CountExtremelyRandomStats'
this is how I call freeze_graph:
def save_model_android():
checkpoint_state_name = "model.ckpt-1"
input_graph_name = "graph.pbtxt"
output_graph_name = "output_graph.pb"
checkpoint_path = os.path.join(model_path, checkpoint_state_name)
input_graph_path = os.path.join(model_path, input_graph_name)
input_saver_def_path = None
input_binary = False
output_node_names = "output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = os.path.join(model_path, output_graph_name)
clear_devices = True
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
input_binary, checkpoint_path,
output_node_names, restore_op_name,
filename_tensor_name, output_graph_path,
clear_devices, "")
I also tried the freezing on the iris dataset in "tf.contrib.learn.datasets.load_iris". I get the same error. So I believe it is not related to the dataset.
- As a second step, I need to use the .pb file on the phone to predict a text. I found the camera demo example by google and it contains a lot of code. I wonder if there is a step by step tutorial how to use a Tensorflow model on Android by passing a feature vector and get the class label.
Thanks, in advance!
UPDATE
By using the recent version of tensorflow (0.12), the problem is solved. However, now, the problem is that what I should pass to output_node_names ??? How can I get what are the output nodes in the graph ?