
I have two models built with ONNX, model A and model B. I use the ONNX Runtime Java API to load these models and make inferences with them. The workflow is that I need to compute a prediction with model A and then feed the result from model A into model B:

x -> A(x) -> B(A(x)) -> y

When I call resultFromA = A.run(inputs) (i.e. OrtSession.run), the API returns a Result. Ideally I'd like to take that result and call B.run(resultFromA), but run only accepts Map<String, OnnxTensor> inputs. Do I really have to iterate over resultFromA and copy its contents into a new Map? Or is there a method or usage in the API that I'm overlooking?

Here is what I'd like to do:

OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession modelASession = environment.createSession(...);
OrtSession modelBSession = environment.createSession(...);

try (Result modelAResults = modelASession.run(inputTensor)) {
    
    try (Result modelBResults = modelBSession.run(modelAResults)) { // <------- run won't take a Result object
        // do something with model B results...
    }
}

And what I have to do instead:

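import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import java.util.HashMap;
import java.util.Map;

OrtEnvironment environment = OrtEnvironment.getEnvironment();
OrtSession modelASession = environment.createSession(...);
OrtSession modelBSession = environment.createSession(...);

try (Result modelAResults = modelASession.run(inputTensor)) {

    // Copy A's outputs into a name -> tensor map. This only works if A's output
    // names match B's input names and every output is a tensor.
    Map<String, OnnxTensor> modelAMap = new HashMap<>();
    modelAResults.forEach(e -> modelAMap.put(e.getKey(), (OnnxTensor) e.getValue()));

    // B must run inside this block: closing modelAResults also closes the tensors in modelAMap.
    try (Result modelBResults = modelBSession.run(modelAMap)) {
        // do something with model B results...
    }
}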
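The copy loop can be pulled into a small reusable helper. The sketch below is generic over any Iterable<Map.Entry<String, V>>, which is what OrtSession.Result implements (with V = OnnxValue), so it compiles and can be tried without ONNX Runtime on the classpath. The class name ResultMaps, the method toInputMap and the optional rename map are hypothetical, not part of the ONNX Runtime API; rename covers the case where model A's output names differ from model B's input names. It copies references only: the tensors still belong to A's Result, so B has to run before that Result is closed.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper: turns one session's result into the input map for the next session. */
final class ResultMaps {
    private ResultMaps() {}

    /**
     * Copies every (name, value) entry from results into a new map, casting each
     * value to type and renaming keys that appear in rename. Values are shared,
     * not copied, so results must stay open while the returned map is in use.
     */
    static <V, T> Map<String, T> toInputMap(
            Iterable<Map.Entry<String, V>> results,
            Class<T> type,
            Map<String, String> rename) {
        Map<String, T> inputs = new LinkedHashMap<>();
        for (Map.Entry<String, V> e : results) {
            V value = e.getValue();
            if (!type.isInstance(value)) {
                throw new IllegalArgumentException(
                        "Output '" + e.getKey() + "' is not a " + type.getSimpleName());
            }
            // Use the downstream input name if a mapping exists, otherwise keep the output name.
            String key = rename.getOrDefault(e.getKey(), e.getKey());
            inputs.put(key, type.cast(value));
        }
        return inputs;
    }

    /** Overload for when A's output names already match B's input names. */
    static <V, T> Map<String, T> toInputMap(
            Iterable<Map.Entry<String, V>> results, Class<T> type) {
        return toInputMap(results, type, Collections.emptyMap());
    }
}
```

Inside the outer try block, the call would then read something like modelBSession.run(ResultMaps.toInputMap(modelAResults, OnnxTensor.class)).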
Joe
  • try casting the Result class to Map – kiranr Feb 11 '21 at 07:32
  • Good idea, but it doesn't work. Result implements AutoCloseable and Iterable<Map.Entry<String, OnnxValue>>, not Map, so the cast throws java.lang.ClassCastException: ai.onnxruntime.OrtSession$Result cannot be cast to java.util.Map – Joe Feb 11 '21 at 12:56
  • I checked the Result class; they should have added a method that just returns a `Map`. – kiranr Feb 11 '21 at 13:02
  • :-( I agree. I'll submit a bug report there. – Joe Feb 11 '21 at 16:21
  • Reported to onnx runtime github as a "feature request": https://github.com/microsoft/onnxruntime/issues/6655 – Joe Feb 11 '21 at 16:28
  • If you know how to add that function to the code, create a pull request and do it yourself; it's faster that way. – kiranr Feb 11 '21 at 16:33

1 Answer

You can pass Collections.singletonMap(key, value) to .run() to build the input map inline.
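Collections.singletonMap returns an immutable map holding exactly one entry, so this shortcut only fits when model B has a single input. A minimal sketch, again generic over the result type so it runs without ONNX Runtime; the class SingleInput, the method forwardSingle and the names "a_out" and "b_in" are hypothetical. The output name of A and the input name of B have to be supplied by the caller, for example after inspecting each session's input and output metadata.

```java
import java.util.Collections;
import java.util.Map;

/** Hypothetical helper for the single-input case. */
final class SingleInput {
    private SingleInput() {}

    /**
     * Finds the output called outputName in results and wraps it as the sole
     * input of the next model under inputName. Throws ClassCastException if the
     * value is not of the requested type.
     */
    static <V, T> Map<String, T> forwardSingle(
            Iterable<Map.Entry<String, V>> results,
            String outputName,
            String inputName,
            Class<T> type) {
        for (Map.Entry<String, V> e : results) {
            if (e.getKey().equals(outputName)) {
                // singletonMap returns an immutable map holding exactly this one entry.
                return Collections.singletonMap(inputName, type.cast(e.getValue()));
            }
        }
        throw new IllegalArgumentException("No output named '" + outputName + "'");
    }
}
```

With ONNX Runtime this would be used as modelBSession.run(SingleInput.forwardSingle(modelAResults, "a_out", "b_in", OnnxTensor.class)), still inside the try block that owns modelAResults.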

  • Your answer could be improved with additional supporting information. Please edit to add further details, such as citations or documentation, so that others can confirm that your answer is correct. You can find more information on how to write good answers in the help center. – Community Jul 09 '22 at 17:46