Following the Python API reference at https://onnxruntime.ai/docs/api/python/api_summary.html, the section on running data on device states: "Users can use the get_outputs() API to get access to the OrtValue(s) corresponding to the allocated output(s). Users can thus consume the ONNX Runtime allocated memory for the output as an OrtValue." How can I actually inspect the contents of these OrtValues for validation?
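For context, the session and device globals are set up roughly like this (the model path, input shape, and batch size below are placeholders for my actual setup):

import numpy as np
import onnxruntime
import torch

# Placeholders standing in for my actual configuration.
DEVICE_NAME = "cuda"
DEVICE_INDEX = 0
DEVICE = f"{DEVICE_NAME}:{DEVICE_INDEX}"
CURR_SIZE = 8  # hypothetical batch size

session = onnxruntime.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["CUDAExecutionProvider"])

x = torch.rand((CURR_SIZE, 3, 224, 224), device=DEVICE)  # hypothetical input shape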
The following script:
def run_with_torch_tensors_on_device(x: torch.Tensor, CURR_SIZE: int,
                                     torch_type: torch.dtype = torch.float) -> torch.Tensor:
    binding = session.io_binding()
    x_tensor = x.contiguous()
    # Pre-allocated output buffer; currently unused, because bind_output below
    # is called without a buffer_ptr and ONNX Runtime allocates the output itself.
    z_tensor = torch.zeros((CURR_SIZE, 91), dtype=torch_type, device=DEVICE).contiguous()
    binding.bind_input(
        name=session.get_inputs()[0].name,
        device_type=DEVICE_NAME,
        device_id=DEVICE_INDEX,
        element_type=np.float32,
        buffer_ptr=x_tensor.data_ptr(),
        shape=x_tensor.shape)
    # No buffer_ptr is given, so ONNX Runtime allocates the output on device.
    binding.bind_output(
        name=session.get_outputs()[-1].name,
        device_type=DEVICE_NAME,
        device_id=DEVICE_INDEX)
    session.run_with_iobinding(binding)
    return binding.get_outputs()[0]  # an OrtValue living on device
simply returns:
<onnxruntime.capi.onnxruntime_inference_collection.OrtValue object at 0x7f652612ae60>
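As far as I can tell, I can query metadata on the returned OrtValue without copying anything to the host:

ort_value = run_with_torch_tensors_on_device(x, CURR_SIZE)
print(ort_value.device_name())  # e.g. 'cuda'
print(ort_value.shape())        # e.g. [8, 91]
print(ort_value.data_ptr())     # raw pointer to the device buffer

but none of that shows the actual values.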
I need to keep the data on the device, so I cannot call .numpy() on the OrtValue, as that would trigger device-to-host communication.
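One direction I have been considering (a sketch, untested) is to bind the pre-allocated z_tensor as the output buffer via bind_output's buffer_ptr argument, so the output lands in a torch tensor that never leaves the device and can be validated with ordinary torch operations:

binding.bind_output(
    name=session.get_outputs()[-1].name,
    device_type=DEVICE_NAME,
    device_id=DEVICE_INDEX,
    element_type=np.float32,
    shape=tuple(z_tensor.shape),
    buffer_ptr=z_tensor.data_ptr())
session.run_with_iobinding(binding)
# z_tensor now holds the output on DEVICE; checks such as
# torch.isnan(z_tensor).any() run entirely on device.

Even so, I would like to know how to inspect the OrtValue that get_outputs() returns directly.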