I have the following code to convert a PyTorch model to ONNX.
# Function to convert the model to ONNX
import argparse
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
from torchvision import models
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmaction.models import build_model
from mmcv import Config
from mmaction.utils import (build_ddp, build_dp, default_device)
def parse_args():
    parser = argparse.ArgumentParser(description='MMAction2')
    parser.add_argument(
        '--conf',
        default='/workspace/mmaction2/work_dirs/slowfast_kinetics_pretrained_r50_8x8x1_cosine_10e_ava22_rgb/slowfast_kinetics_pretrained_r50_8x8x1_cosine_10e_ava22_rgb.py',
        help='config file path')
    parser.add_argument(
        '--checkpoint',
        default='/workspace/mmaction2/work_dirs/slowfast_kinetics_pretrained_r50_8x8x1_cosine_10e_ava22_rgb/latest.pth',
        help='checkpoint file path')
    args = parser.parse_args()
    return args

def turn_off_pretrained(cfg):
    # recursively find all 'pretrained' keys in the model config
    # and set them to None to avoid redundant pretraining steps during testing
    if 'pretrained' in cfg:
        cfg.pretrained = None
    # recursively turn off pretrained values in sub-configs
    for sub_cfg in cfg.values():
        if isinstance(sub_cfg, dict):
            turn_off_pretrained(sub_cfg)

def convert(args):
    cfg = Config.fromfile(args.conf)
    turn_off_pretrained(cfg.model)

    # build the model and load checkpoint
    model = build_model(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = build_dp(model, default_device, default_args=dict(device_ids=cfg.gpu_ids))

    # set the model to inference mode
    model.eval()
    # device = torch.device("cpu" if args.cpu else "cuda")
    # model = model.to(device)

    # Let's create a dummy input tensor
    dummy_input = torch.randn(1, 3, 1080, 1920).to(default_device)

    # Export the model
    torch.onnx.export(model.module,              # model being run
                      dummy_input,               # model input (or a tuple for multiple inputs)
                      "slow_fast_ava.onnx",      # where to save the model
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=11,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],     # the model's input names
                      output_names=['output'],   # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                    'output': {0: 'batch_size'}})
    print(" ")
    print('Model has been converted to ONNX')


if __name__ == '__main__':
    args = parse_args()
    convert(args)
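As a side note, to see which positional arguments the wrapped module's forward actually expects, I can add a quick check right before the torch.onnx.export call (plain Python introspection, nothing MMAction2-specific):

import inspect
print(inspect.signature(model.module.forward))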
The PyTorch model is attached here.
The error during conversion is:
File "mmaction2/demo/convert_to_onnx_1.py", line 57, in convert
torch.onnx.export(model.module, # model being run
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
return utils.export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
_export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1074, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 727, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
result = self.forward(*input, **kwargs)
File "/workspace/mmcv/mmcv/runner/fp16_utils.py", line 116, in new_func
return old_func(*args, **kwargs)
File "/workspace/mmdetection/mmdet/models/detectors/base.py", line 168, in forward
assert len(img_metas) == 1
File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 705, in __len__
raise TypeError("len() of a 0-d tensor")
TypeError: len() of a 0-d tensor
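For context on the last frame: PyTorch raises exactly this TypeError whenever len() is called on a scalar (0-d) tensor, so the assert in mmdet's base.py is apparently receiving a 0-d tensor where it expects a list of img_metas. A minimal, standalone reproduction of just the TypeError (unrelated to the model itself):

import torch

t = torch.tensor(1.0)  # a 0-d (scalar) tensor
len(t)                 # raises TypeError: len() of a 0-d tensor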