I have an issue with a TensorFlow model that was converted PyTorch -> ONNX -> TensorFlow. The converted TensorFlow model expects its input in PyTorch layout, i.e. (batch size, number of channels, height, width), rather than in TensorFlow layout (batch size, height, width, number of channels). Because of this, I cannot process the model further with Vitis AI.
So I would like to ask: is there any way to convert this PyTorch input layout to the TensorFlow layout using tools from ONNX, TensorFlow 1, or elsewhere?
My code is as below:
PyTorch -> ONNX
from hardnet import hardnet
import torch
import onnx
# Set to True to export a model whose input is NHWC (TensorFlow layout:
# batch, height, width, channels) instead of NCHW (PyTorch layout). The
# NHWC -> NCHW permute then becomes part of the exported ONNX graph, so a
# TensorFlow graph converted from it should take NHWC data on 'input0'.
# False keeps the original NCHW export.
EXPORT_NHWC = False


class NHWCInput(torch.nn.Module):
    """Adapter that lets a channels-first model accept channels-last input."""

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        # (N, H, W, C) -> (N, C, H, W), the layout the wrapped model takes.
        return self.net(x.permute(0, 3, 1, 2))


# map_location='cpu' lets a checkpoint saved with GPU tensors load on a
# CPU-only machine; the export below only needs the weights on CPU.
ckpt = torch.load('../hardnet.pth', map_location='cpu')
model = hardnet(11)
model.load_state_dict(ckpt['model_state_dict'])
# Put layers such as batch norm / dropout into inference mode before tracing.
model.eval()

if EXPORT_NHWC:
    model = NHWCInput(model).eval()
    dummy_input = torch.randn(1, 1080, 1920, 3)  # (batch, H, W, C)
else:
    dummy_input = torch.randn(1, 3, 1080, 1920)  # (batch, C, H, W)

input_names = ['input0']
output_names = ['output0']
output_file = 'hardnet.onnx'
# keep_initializers_as_inputs=True additionally lists the weights
# (initializers) as graph inputs in the exported ONNX model.
torch.onnx.export(model, dummy_input, output_file, verbose=True,
                  input_names=input_names, output_names=output_names,
                  opset_version=11, keep_initializers_as_inputs=True)

# Re-load and validate the written file against the ONNX spec.
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)
print('Passed Onnx')
ONNX -> TensorFlow 1 (using TensorFlow 1.15)
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import onnx
from onnx_tf.backend import prepare
# Convert the ONNX model to TensorFlow with the onnx-tf backend and write
# the resulting graph to 'hardnet.pb'.
output_file = 'hardnet.onnx'
onnx_model = onnx.load(output_file)
tf_rep = prepare(onnx_model)  # onnx-tf representation of the ONNX model
tf_rep.export_graph('hardnet.pb')

# Run TensorFlow in graph mode (eager is already off by default in TF 1.x).
tf.compat.v1.disable_eager_execution()
def load_pb(path_to_pb: str) -> "tf.Graph":
    """Load a serialized GraphDef (.pb file) into a new, standalone tf.Graph.

    From: https://stackoverflow.com/questions/51278213/what-is-the-use-of-a-pb-file-in-tensorflow-and-how-does-it-work

    Uses the tf.io / tf.compat.v1 entry points, which exist in both TF 1.15
    and TF 2.x, instead of the TF1-only tf.gfile / tf.GraphDef aliases.

    Args:
        path_to_pb: Path to the binary GraphDef protobuf.

    Returns:
        A tf.Graph containing the imported nodes. Passing name='' suppresses
        the default 'import/' prefix, so nodes keep their original names.
    """
    with tf.io.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def, name='')
    return graph
graph = load_pb('hardnet.pb')
# Distinct names so the built-in input() is not shadowed.
input_tensor = graph.get_tensor_by_name('input0:0')
output_tensor = graph.get_tensor_by_name('output0:0')

# Per-channel normalization statistics (the common ImageNet values).
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

img = cv2.imread('train_0.jpg', cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports an unreadable/missing file by returning None.
    raise FileNotFoundError("could not read image 'train_0.jpg'")
# cv2.resize takes dsize as (width, height); the result is HWC (1080, 1920, 3).
img = cv2.resize(img, (1920, 1080))
# NOTE(review): cv2 loads channels in BGR order, while these mean/std values
# are usually listed in RGB order. If the model was trained on RGB input,
# insert img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) here — confirm against
# the training pipeline.
img = (img.astype(np.float32) / 255.0 - mean) / std
# HWC -> NCHW with a leading batch axis: (1, 3, 1080, 1920), matching the
# PyTorch-layout input of the converted graph.
img = np.transpose(img, (2, 0, 1))[np.newaxis]

with tf.compat.v1.Session(graph=graph) as sess:
    pred = sess.run(output_tensor, {input_tensor: img})