
Following the API documentation at https://onnxruntime.ai/docs/api/python/api_summary.html, the section on running data on device states: "Users can use the get_outputs() API to get access to the OrtValue(s) corresponding to the allocated output(s). Users can thus consume the ONNX Runtime allocated memory for the output as an OrtValue." How do I actually see what is inside these OrtValues for validation?

The following function:

import numpy as np
import torch

# session, DEVICE, DEVICE_NAME, and DEVICE_INDEX are defined at module scope
def run_with_torch_tensors_on_device(x: torch.Tensor, CURR_SIZE: int, torch_type: torch.dtype = torch.float) -> torch.Tensor:
    binding = session.io_binding()
    x_tensor = x.contiguous()
    # Preallocated output buffer (not bound to the output below)
    z_tensor = torch.zeros((CURR_SIZE, 91), dtype=torch_type, device=DEVICE).contiguous()

    # Bind the input directly to the torch tensor's device memory
    binding.bind_input(
        name=session.get_inputs()[0].name,
        device_type=DEVICE_NAME,
        device_id=DEVICE_INDEX,
        element_type=np.float32,
        buffer_ptr=x_tensor.data_ptr(),
        shape=x_tensor.shape)

    # Let ONNX Runtime allocate the output on the same device
    binding.bind_output(
        name=session.get_outputs()[-1].name,
        device_type=DEVICE_NAME,
        device_id=DEVICE_INDEX)

    session.run_with_iobinding(binding)

    # binding.get_outputs() yields OrtValue objects, not torch tensors
    return binding.get_outputs()[0]

simply returns:

<onnxruntime.capi.onnxruntime_inference_collection.OrtValue object at 0x7f652612ae60>
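
I can at least read metadata off the returned OrtValue without copying the data to host (a minimal sketch; device_name(), shape(), and data_ptr() are accessors documented on the same API page, and ort_out stands for the value returned above):

ort_out = run_with_torch_tensors_on_device(x, CURR_SIZE)
print(ort_out.device_name())  # e.g. "cuda" -- confirms the buffer stayed on device
print(ort_out.shape())        # output dimensions
print(ort_out.data_ptr())     # raw pointer to the device buffer

That tells me where the output lives, but not what it contains.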

I need to keep the data on device, so I cannot call .numpy() on the OrtValue, as that would trigger device-to-host communication.
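
One direction that seems consistent with the same documentation page is to bind the preallocated z_tensor as the output buffer, so the result is written into a torch tensor I can validate on device. A sketch (unverified, and assuming (CURR_SIZE, 91) is the correct shape for the bound output):

# Sketch: bind the output to z_tensor's device memory instead of letting
# ONNX Runtime allocate it, so the result lands in a torch tensor.
binding.bind_output(
    name=session.get_outputs()[-1].name,
    device_type=DEVICE_NAME,
    device_id=DEVICE_INDEX,
    element_type=np.float32,
    buffer_ptr=z_tensor.data_ptr(),
    shape=tuple(z_tensor.shape))

session.run_with_iobinding(binding)

# Validation can then run on device with torch ops, e.g.:
assert not torch.isnan(z_tensor).any()

Is that the intended approach, or is there a way to view the contents of an ONNX Runtime allocated OrtValue directly?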
