One of the major problems I've encountered when converting PyTorch models to TensorFlow through ONNX is that the converted model runs slowly. This appears to be related to the input shape, even though I was able to get bit-exact outputs from the two frameworks.
While the PyTorch input shape is B,C,H,W, the TensorFlow input shape is B,H,W,C, where B, C, H and W stand for batch size, channels, height and width, respectively. Technically, the input shape problem is easy to solve when working in TensorFlow, using two calls to np.swapaxes:
```python
import numpy as np
# image: a single H,W,C array, no batch dimension yet
image = np.swapaxes(image, 0, 2)  # swap axes 0 and 2 (H and C) - result: C,W,H
image = np.swapaxes(image, 1, 2)  # swap axes 1 and 2 (W and H) - result: C,H,W (like PyTorch)
```
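For reference, the two swapaxes calls amount to a single axis permutation, so the same conversion can be written with one np.transpose call, and the batched variants only differ in the permutation tuple. A minimal NumPy sketch (the helper names are my own, not from any library):

```python
import numpy as np


def hwc_to_chw(image):
    """Single image: H,W,C -> C,H,W (same result as the two swapaxes calls)."""
    return np.transpose(image, (2, 0, 1))


def nhwc_to_nchw(batch):
    """Batched input: B,H,W,C -> B,C,H,W."""
    return np.transpose(batch, (0, 3, 1, 2))


def nchw_to_nhwc(batch):
    """Inverse permutation: B,C,H,W -> B,H,W,C."""
    return np.transpose(batch, (0, 2, 3, 1))


if __name__ == "__main__":
    image = np.random.rand(4, 5, 3)  # H=4, W=5, C=3
    two_swaps = np.swapaxes(np.swapaxes(image, 0, 2), 1, 2)
    print(np.array_equal(two_swaps, hwc_to_chw(image)))  # True
    print(hwc_to_chw(image).shape)                        # (3, 4, 5)
```

Note that np.transpose returns a strided view rather than a copy; if the consumer needs a contiguous buffer, wrap the result in np.ascontiguousarray.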
The slowness seems to be related to the different data layouts PyTorch and TensorFlow use for convolutions: PyTorch expects channels first (B,C,H,W), whereas TensorFlow expects channels last (B,H,W,C) by default.
Probably as a result, when I visualize the models in Netron, the ONNX model looks clean and readable (first image), whereas the TensorFlow .pb model looks like a big mess (second image).
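A plausible mechanism, which I have not verified against the converter's source: if TensorFlow's convolution runs in its default channels-last layout, each channels-first Conv from the ONNX graph has to be wrapped in layout conversions, i.e. transpose the input (and weights) to channels-last, convolve, and transpose the result back. Repeated for every layer, such extra Transpose ops would both cost time and clutter the graph. The sketch below uses naive, hypothetical NumPy helpers (not TensorFlow or onnx-tf code) to check that the wrapped formulation matches a direct channels-first convolution:

```python
import numpy as np


def conv2d_nchw_reference(x, w):
    """Direct channels-first convolution (cross-correlation, stride 1, no padding).
    x: (N, C_in, H, W), w: (C_out, C_in, KH, KW) -> (N, C_out, H-KH+1, W-KW+1)."""
    n, _, h, wd = x.shape
    c_out, _, kh, kw = w.shape
    out = np.zeros((n, c_out, h - kh + 1, wd - kw + 1), dtype=np.result_type(x, w))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]  # (N, C_in, KH, KW)
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out


def conv2d_nhwc(x, w):
    """Channels-last convolution, same semantics.
    x: (N, H, W, C_in), w: (KH, KW, C_in, C_out) -> (N, H-KH+1, W-KW+1, C_out)."""
    n, h, wd, _ = x.shape
    kh, kw, _, c_out = w.shape
    out = np.zeros((n, h - kh + 1, wd - kw + 1, c_out), dtype=np.result_type(x, w))
    for i in range(h - kh + 1):
        for j in range(wd - kw + 1):
            patch = x[:, i:i + kh, j:j + kw, :]  # (N, KH, KW, C_in)
            out[:, i, j, :] = np.tensordot(patch, w, axes=([1, 2, 3], [0, 1, 2]))
    return out


def conv2d_nchw_via_nhwc(x_nchw, w_oihw):
    """Run a channels-first conv on a channels-last kernel by wrapping it in transposes."""
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))  # N,C,H,W -> N,H,W,C
    w_hwio = np.transpose(w_oihw, (2, 3, 1, 0))  # C_out,C_in,KH,KW -> KH,KW,C_in,C_out
    y_nhwc = conv2d_nhwc(x_nhwc, w_hwio)
    return np.transpose(y_nhwc, (0, 3, 1, 2))    # N,H,W,C -> N,C,H,W


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 3, 5, 6))  # N=2, C_in=3, H=5, W=6
    w = rng.standard_normal((4, 3, 3, 2))  # C_out=4, C_in=3, KH=3, KW=2
    y = conv2d_nchw_via_nhwc(x, w)
    print(y.shape)                                      # (2, 4, 3, 5)
    print(np.allclose(y, conv2d_nchw_reference(x, w)))  # True
```

The naive loops are only there for clarity; in a real conversion the layout handling is decided by the converter and by TensorFlow's kernels.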
Note: this problem seems to have already concerned the authors of the onnx2keras library, which offers an experimental option for changing the C,H,W ordering that originates in PyTorch into H,W,C.
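A rough intuition for why a whole-graph reordering like the one onnx2keras experiments with could help: once every layer works in H,W,C, the transpose that ends one layer and the transpose that starts the next are inverses and cancel, leaving only one conversion at the input and one at the output. Below is a toy peephole pass over a hypothetical list-of-ops representation (not the API of onnx2keras, onnx-tf or TensorFlow) that merges adjacent transposes and drops those that compose to the identity. It deliberately ignores the harder parts of a real pass, such as pushing transposes through element-wise ops like ReLU or through residual additions.

```python
# Toy graph: a list of ops, each either ("transpose", perm) or (op_type, label).
# This representation is hypothetical and only illustrates the idea.
NCHW_TO_NHWC = (0, 2, 3, 1)
NHWC_TO_NCHW = (0, 3, 1, 2)


def compose(p, q):
    """Permutation r such that np.transpose(np.transpose(x, p), q) == np.transpose(x, r)."""
    return tuple(p[k] for k in q)


def cancel_transposes(ops):
    """Merge adjacent transposes; drop any merged transpose that is the identity."""
    out = []
    for op in ops:
        if op[0] == "transpose" and out and out[-1][0] == "transpose":
            merged = compose(out.pop()[1], op[1])
            if merged != tuple(range(len(merged))):
                out.append(("transpose", merged))
        else:
            out.append(op)
    return out


if __name__ == "__main__":
    to_nhwc = ("transpose", NCHW_TO_NHWC)
    to_nchw = ("transpose", NHWC_TO_NCHW)
    # Two channels-first convs, each wrapped in its own layout conversions.
    graph = [to_nhwc, ("conv", "conv1"), to_nchw, to_nhwc, ("conv", "conv2"), to_nchw]
    print(cancel_transposes(graph))
    # [('transpose', (0, 2, 3, 1)), ('conv', 'conv1'), ('conv', 'conv2'), ('transpose', (0, 3, 1, 2))]
```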
Any idea how to overcome this limitation? Are there other options for exporting PyTorch models to TensorFlow that keep the graph at a higher level of abstraction?
ONNX model (exported from PyTorch): the straight data flow and the residual blocks are clearly visible.
TensorFlow model (converted from the ONNX model): hardly anything resembles a series of predefined operations.