I have two models built with ONNX, model A and model B. I use the ONNX Runtime Java API to load these models and make inferences with them. The workflow is that I need to compute a prediction with model A and then feed the result from model A into model B:
x -> A(x) -> B(A(x)) -> y
When I call resultFromA = A.run(inputs) (OrtSession.run), the API returns a Result. Ideally I'd like to take that result and call B.run(resultFromA), but run only accepts Map<String, OnnxTensor> inputs. Do I really have to iterate over resultFromA and put its contents into a new Map? Or is there a method/usage from the API that I'm overlooking?
Here is what I'd like to do:
OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession modelASession = environment.createSession(...);
OrtSession modelBSession = environment.createSession(...);
try (Result modelAResults = modelASession.run(inputTensor)) {
    try (Result modelBResults = modelBSession.run(modelAResults)) { // <------- run won't take a Result object
        // do something with model B results...
    }
}
And what I have to do instead:
OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession modelASession = environment.createSession(...);
OrtSession modelBSession = environment.createSession(...);
try (Result modelAResults = modelASession.run(inputTensor)) {
    // Copy model A's outputs into a map keyed by output name; the names must match model B's input names
    Map<String, OnnxTensor> modelAMap = new HashMap<>();
    modelAResults.forEach(e -> modelAMap.put(e.getKey(), (OnnxTensor) e.getValue()));
    try (Result modelBResults = modelBSession.run(modelAMap)) {
        // do something with model B results...
    }
}
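For reference, the copy step above could at least be hidden behind a small helper. This is only a sketch: ResultMaps and toTypedMap are hypothetical names of my own, not part of ONNX Runtime. It relies on Result being iterable as Map.Entry<String, OnnxValue>, which is also what the forEach call above uses, and is written against plain java.util types so it compiles without the library.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of the ONNX Runtime API.
final class ResultMaps {
    private ResultMaps() {}

    /**
     * Copies named entries into a new map, casting each value to the requested type.
     * Throws ClassCastException if a value is not an instance of that type.
     */
    static <V, T extends V> Map<String, T> toTypedMap(
            Iterable<? extends Map.Entry<String, ? extends V>> entries, Class<T> type) {
        // LinkedHashMap keeps the entries in the order they were produced
        Map<String, T> out = new LinkedHashMap<>();
        for (Map.Entry<String, ? extends V> e : entries) {
            out.put(e.getKey(), type.cast(e.getValue()));
        }
        return out;
    }
}
```

With that in place, the inner call would read modelBSession.run(ResultMaps.toTypedMap(modelAResults, OnnxTensor.class)). The call still has to sit inside the outer try block, because closing modelAResults releases the tensors it holds, and the keys are still model A's output names, so they only work if model B's inputs use the same names.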