I want to train a TensorFlow model in C++ using the Python/C API (I am aware of TensorFlow's C++ API, but it is too restricted). First I create the model in Python and export it. After that I reload it in C++.
The problem: unfortunately, the Python wrapping in C++ seems to drop TensorFlow's default graph while restoring the TensorFlow session.
Here is the working import code in Python:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Show the default graph before anything is imported, so it can be compared
# with the one printed at the end of the script.
print("Default graph before:", tf.get_default_graph())
# Load the exported meta graph from "graph.pb.meta"; the object returned here
# is the one used to restore the session below.
saver = tf.train.import_meta_graph("graph.pb.meta")
# Create the session that the checkpoint is restored into.
sess = tf.Session()
# Restore the saved state from "graph.pb" into the session.
saver.restore(sess, "graph.pb")
# Show the default graph again after the restore.
print("Default graph after:", tf.get_default_graph())
This prints the same Python graph object both times:
Default graph before: <tensorflow.python.framework.ops.Graph object at 0x7f382b762588>
Default graph after: <tensorflow.python.framework.ops.Graph object at 0x7f382b762588>
Now the non-working C++ code (sorry for so much code; each line of the Python code above is marked with a comment):
int main(){
// Init python
Py_Initialize();
std::cout << "Python version: " << Py_GetVersion() << std::endl;
// py: import tensorflow
PyObject *pName = PyUnicode_FromString("tensorflow");
PyObject *fModule = PyImport_Import(pName);
if(!fModule) std::cout << "Import tensorflow failed." << std::endl;
PyObject *pDict = PyModule_GetDict(fModule);
// py: print("Default graph before:", tf.get_default_graph())
PyObject* pGetDefaultGraph = PyDict_GetItemString(pDict, "get_default_graph");
PyObject* pDefaultGraph;
if(PyCallable_Check(pGetDefaultGraph)){
pDefaultGraph = PyObject_CallFunction(pGetDefaultGraph, 0);
std::cout << "Default graph before: ";
PyObject_Print(pDefaultGraph, stdout, 0);
std::cout << std::endl;
}
else std::cout << "tensorflow.get_default_graph() is not callable.";
// py: saver = tf.train.import_meta_graph("graph.pb.meta")
PyObject* trainModule = PyDict_GetItemString(pDict, "train");
PyObject* importMetaGraph = PyObject_GetAttrString(trainModule,"import_meta_graph");
PyObject* fSaver;
if(PyCallable_Check(importMetaGraph)){
fSaver = PyObject_CallFunction(importMetaGraph, "(s)", "graph.pb.meta");
}
else std::cout << "Cannot create Tensorflow saver from imported meta graph" << std::endl;
if(fSaver==0) std::cout << "Tensorflow model failed to load from file \"" << "graph.pb.meta" << "\"" << std::endl;
// py: sess = tf.Session()
PyObject* session = PyDict_GetItemString(pDict, "Session");
PyObject* fSession;
if(PyCallable_Check(session)){
fSession = PyObject_CallObject(session, 0);
}
else std::cout << "Cannot create Tensorflow session" << std::endl;
if(fSession==0) std::cout << "Tensorflow session points to zero, failed to create session" << std::endl;
// py: saver.restore(sess, "graph.pb")
PyObject_CallMethod(fSaver, "restore", "(0s)", fSession, "graph.pb");
// py: print("Default graph after:", tf.get_default_graph())
std::cout << "Default graph after: ";
pDefaultGraph = PyObject_CallFunction(pGetDefaultGraph, 0);
PyObject_Print(pDefaultGraph, stdout, 0);
std::cout << std::endl;
}
And here is the problem: this produces the following output, where the default graph disappears after restoring the session:
Python version: 3.5.1 (default, Mar 3 2016, 09:29:07) [GCC 5.3.0]
Default graph before: <tensorflow.python.framework.ops.Graph object at 0x7f3e8cd572b0>
Default graph after: <nil>
I know, weird question, but I am totally confused!