I am trying to replace the computation done in the graph with a custom op that does the same thing. Say the graph has a constant A and a weight variable W. I create the custom op to take these two as inputs and perform the entire computation (except for the last step, the weight update):
    import tensorflow as tf

    custom_op_tensor = custom_module.custom_op([A, W])
    g_def = tf.get_default_graph().as_graph_def()
    # `tensor` is the output of the original computation being replaced
    input_map = {tensor.name: custom_op_tensor}
    # return_elements takes names, not ops
    train_op, = tf.import_graph_def(g_def, input_map=input_map,
                                    return_elements=[train_op.name])
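For reference, here is a minimal sketch of the kind of graph I mean; the shapes, loss, and optimizer are placeholders for illustration, not my real model:

    import tensorflow as tf

    A = tf.constant([[1.0, 2.0]], name='A')        # the constant A
    W = tf.Variable([[0.5], [0.5]], name='W')      # the weight variable W
    tensor = tf.matmul(A, W, name='forward')       # computation the custom op replaces
    loss = tf.reduce_sum(tensor)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)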
After import_graph_def, there are two W's: one from the original graph def and one in the imported graph. When we run the train op, the custom op ends up reading the old W while the new W is the one that gets updated. As a result, gradient descent fails to do the right thing.
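For example, listing the graph's operations after the import shows both copies (assuming the imported copy lands under the default import/ name scope):

    g = tf.get_default_graph()
    # Both the original variable and the imported copy are present:
    print([op.name for op in g.get_operations()
           if op.name == 'W' or op.name.endswith('/W')])
    # -> something like ['W', 'import/W']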
The problem is circular: instantiating the custom op requires the input weight tensor W, the new W is known only after the import, and the import itself requires the custom op.
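Concretely (with import/W:0 as the assumed name of the imported copy):

    # The custom op must be wired to W *before* the import ...
    custom_op_tensor = custom_module.custom_op([A, W])   # binds to the old W
    # ... while the W that train_op actually updates only exists afterwards:
    new_W = tf.get_default_graph().get_tensor_by_name('import/W:0')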
How does one get around this problem?