I am looking for a feature in ONNX/ONNX Runtime similar to a Keras custom layer. As I understand it, the way to do this is to implement a custom operator in ONNX for experimentation. The documentation seems to point to implementing the operator in C++, building it as a shared library, and then using it from Python: https://onnxruntime.ai/docs/reference/operators/add-custom-op.html
Is there a way to define a custom op for ONNX in Python, just for experimental purposes, and use it for inference? I tried following the guide below, but it fails with 'error: PyOp is not a registered function/op': https://onnxruntime.ai/docs/reference/operators/custom-python-operator.html
Python code:
```python
import numpy as np
import onnx
import onnxruntime as ort

# Graph inputs/outputs: float32 vectors of length 4
A = onnx.helper.make_tensor_value_info('A', onnx.TensorProto.FLOAT, [4])
B = onnx.helper.make_tensor_value_info('B', onnx.TensorProto.FLOAT, [4])
C = onnx.helper.make_tensor_value_info('C', onnx.TensorProto.FLOAT, [4])
D = onnx.helper.make_tensor_value_info('D', onnx.TensorProto.FLOAT, [4])
E = onnx.helper.make_tensor_value_info('E', onnx.TensorProto.FLOAT, [4])
F = onnx.helper.make_tensor_value_info('F', onnx.TensorProto.FLOAT, [4])

# F = ((A + B) + (C * D)) + 2, where the final "+ 2" is done in Python by mymodule.Add_2
ad1_node = onnx.helper.make_node('Add', ['A', 'B'], ['S'])
mul_node = onnx.helper.make_node('Mul', ['C', 'D'], ['P'])
ad2_node = onnx.helper.make_node('Add', ['S', 'P'], ['H'])
py1_node = onnx.helper.make_node(op_type='PyOp',                          # required, must be 'PyOp'
                                 inputs=['H'],                             # required
                                 outputs=['F'],                            # required
                                 domain='pyopadd_2',                       # required, must be unique
                                 input_types=[onnx.TensorProto.FLOAT],     # required
                                 output_types=[onnx.TensorProto.FLOAT],    # required
                                 module='mymodule',                        # required
                                 class_name='Add_2',                       # required
                                 compute='compute')                        # optional, 'compute' by default

graph = onnx.helper.make_graph([ad1_node, mul_node, ad2_node, py1_node], 'multi_pyop_graph', [A, B, C, D], [F])
# Import the default ONNX domain ('') for Add/Mul as well as the custom 'pyopadd_2' domain
model = onnx.helper.make_model(graph,
                               opset_imports=[onnx.helper.make_opsetid('', 13),
                                              onnx.helper.make_opsetid('pyopadd_2', 1)],
                               producer_name='pyop_model')
onnx.save(model, './modeltemp.onnx')

ort_session = ort.InferenceSession('./modeltemp.onnx')
# Feed numpy arrays whose dtype matches the declared input type (float32)
ort_output = ort_session.run(['F'], {
    'A': np.array([1, 2, 3, 4], dtype=np.float32),
    'B': np.array([1, 1, 1, 1], dtype=np.float32),
    'C': np.array([2, 2, 2, 2], dtype=np.float32),
    'D': np.array([3, 3, 3, 3], dtype=np.float32),
})
print(ort_output)
```
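mymodule.py:
```python
class Add_2:
    # Invoked by the PyOp node: receives the node's input ('H') and returns its output ('F')
    def compute(self, S):
        return S + 2
```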
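As a sanity check that does not depend on ONNX Runtime at all, the same computation can be written in plain NumPy. This is only a sketch: `reference_forward` is a hypothetical helper, and it assumes `Add_2.compute` simply adds 2 to its input, as in mymodule.py. Each line mirrors one node of the graph, so the result is what `ort_output[0]` should contain once the PyOp actually runs.

```python
import numpy as np

def reference_forward(A, B, C, D):
    """Hypothetical NumPy equivalent of the graph: Add -> Mul -> Add -> Add_2.compute."""
    S = A + B      # ad1_node: Add(A, B)
    P = C * D      # mul_node: Mul(C, D)
    H = S + P      # ad2_node: Add(S, P)
    return H + 2   # py1_node: assumed to behave like mymodule.Add_2.compute(H)

A = np.array([1, 2, 3, 4], dtype=np.float32)
B = np.array([1, 1, 1, 1], dtype=np.float32)
C = np.array([2, 2, 2, 2], dtype=np.float32)
D = np.array([3, 3, 3, 3], dtype=np.float32)

F_expected = reference_forward(A, B, C, D)
print(F_expected)  # [10. 11. 12. 13.]
```

Once the session runs, `np.testing.assert_allclose(ort_output[0], F_expected)` compares the ONNX Runtime result against this reference.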