I am trying to convert a pre-trained PyTorch model to ONNX, but I receive the following error:
RuntimeError: step!=1 is currently not supported
I'm trying this on a pre-trained colorization model: https://github.com/richzhang/colorization
Here is the code I ran in Google Colab:
```python
!git clone https://github.com/richzhang/colorization.git
%cd colorization/

import torch
import colorizers

model = colorizers.siggraph17(pretrained=True).eval()

input_names = ["input"]
output_names = ["output"]
dummy_input = torch.randn(1, 1, 256, 256, device='cpu')

torch.onnx.export(model, dummy_input, "test_converted_model.onnx", verbose=True,
                  input_names=input_names, output_names=output_names)
```
I appreciate any help :)
UPDATE 1: @Proko's suggestion solved the ONNX export issue. Now I have a new, possibly related problem: when I try to convert the ONNX model to TensorRT, I get the following error:
[TensorRT] ERROR: Network must have at least one output
Here is the code I used:
```python
import torch
import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt
import onnx

TRT_LOGGER = trt.Logger()

def build_engine(onnx_file_path):
    # initialize TensorRT engine and parse ONNX model
    builder = trt.Builder(TRT_LOGGER)
    builder.max_workspace_size = 1 << 25
    builder.max_batch_size = 1
    if builder.platform_has_fast_fp16:
        builder.fp16_mode = True

    network = builder.create_network()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # parse ONNX
    with open(onnx_file_path, 'rb') as model:
        print('Beginning ONNX file parsing')
        parser.parse(model.read())
        print('Completed parsing of ONNX file')

    # generate TensorRT engine optimized for the target platform
    print('Building an engine...')
    engine = builder.build_cuda_engine(network)
    context = engine.create_execution_context()
    print("Completed creating Engine")
    return engine, context

ONNX_FILE_PATH = 'siggraph17.onnx'  # Exported using the code above
engine, _ = build_engine(ONNX_FILE_PATH)
```
I tried to force `build_engine` to mark the output of the network's last layer by adding:

```python
network.mark_output(network.get_layer(network.num_layers - 1).get_output(0))
```

but it did not work. I appreciate any help!