
One of the major problems I've encountered when converting PyTorch models to TensorFlow through ONNX is slowness, which appears to be related to the input shape, even though I was able to get bit-exact outputs from the two frameworks.

While the PyTorch input shape is B,C,H,W, the TensorFlow input shape is B,H,W,C, where B, C, H and W stand for batch size, channels, height and width, respectively. Technically, I solve the input shape problem easily when working in TensorFlow, using two calls to np.swapaxes:

# Single image in TensorFlow's H,W,C layout, no batch dimension yet
image = np.swapaxes(image, 0, 2)   # Swapping C and H dimensions - result: C,W,H
image = np.swapaxes(image, 1, 2)   # Swapping H and W dimensions - result: C,H,W (like PyTorch)
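For reference, the same reordering can be done with a single np.transpose call, followed by adding the batch dimension. A minimal sketch (the 224x224 RGB shape is just an assumption for illustration):

```python
import numpy as np

# Dummy single image in TensorFlow's H,W,C layout (shape is an assumption)
image = np.zeros((224, 224, 3), dtype=np.float32)

# One transpose instead of two swapaxes: H,W,C -> C,H,W
chw = np.transpose(image, (2, 0, 1))

# Add the batch dimension the network expects: B,C,H,W
batch = chw[np.newaxis, ...]
print(batch.shape)  # (1, 3, 224, 224)
```

This is equivalent to the two swapaxes calls above; it does not change the slowness issue, only the readability of the preprocessing.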

The slowness problem seems to be related to differences in the way convolution is implemented in PyTorch vs TensorFlow: PyTorch expects channels first (NCHW), while TensorFlow expects channels last (NHWC).
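To illustrate why this layout mismatch can bloat the converted graph: if a converter keeps every op in ONNX's channels-first semantics while the surrounding tensors are channels-last, it effectively wraps each op in a pair of Transpose nodes. A numpy sketch with a toy 1x1 convolution (illustrative only, not actual converter code):

```python
import numpy as np

def conv1x1_nchw(x, w):
    # Toy 1x1 convolution in channels-first layout (B,C,H,W):
    # y[b, o, h, w] = sum_c w[o, c] * x[b, c, h, w]
    return np.einsum('oc,bchw->bohw', w, x)

def conv1x1_nhwc(x_nhwc, w):
    # What a naive ONNX->TF conversion effectively emits: the original
    # channels-first op sandwiched between two Transpose nodes.
    x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))   # NHWC -> NCHW
    y_nchw = conv1x1_nchw(x_nchw, w)
    return np.transpose(y_nchw, (0, 2, 3, 1))     # NCHW -> NHWC

x_nchw = np.random.rand(1, 3, 8, 8).astype(np.float32)
w = np.random.rand(4, 3).astype(np.float32)       # (out_channels, in_channels)

y_direct = conv1x1_nchw(x_nchw, w)
y_wrapped = conv1x1_nhwc(np.transpose(x_nchw, (0, 2, 3, 1)), w)

# Same numbers, but the wrapped version pays two extra transposes per op -
# which is consistent with the imported graph being littered with Transpose nodes.
assert np.allclose(y_direct, np.transpose(y_wrapped, (0, 3, 1, 2)))
```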

As a result, when I visualize the models using Netron, the ONNX model looks abstract and makes sense (first image), whereas the TensorFlow .pb-formatted model looks like a big mess (second image).

Note: It appears that this problem has already concerned the authors of the onnx2keras library, which supports an experimental feature for changing the C,H,W ordering originating in PyTorch into H,W,C.
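If I understand its README correctly, that experimental feature is exposed via a change_ordering flag. An untested sketch (the model path and input name are placeholders for your own):

```python
import onnx
from onnx2keras import onnx_to_keras

# Placeholder path and input name - substitute your own
onnx_model = onnx.load('model.onnx')
k_model = onnx_to_keras(onnx_model, ['input'], change_ordering=True)
```

With change_ordering=True the resulting Keras model should accept H,W,C input directly, which would remove the per-op transposes, though the feature is documented as experimental.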

Any idea how to overcome this limitation? Are there other options for exporting PyTorch models into TensorFlow at a higher level of abstraction?

ONNX (from PyTorch) - you can see the straight flow and the residual blocks:


TensorFlow (imported from the ONNX model) - almost nothing looks like a series of predefined operations:


SomethingSomething
  • Intuitively, the input shape should not be the source of the problem. Can you provide more details about your code (both PyTorch and TensorFlow)? – zhf061 Dec 31 '19 at 16:20
  • @zhf061 I thought the input shape could be a problem, since if TensorFlow expects `H,W,C`, then its convolutional layers may be implemented accordingly. When working in TensorFlow with an originally PyTorch model that I ported using ONNX (which still expects PyTorch's shape of `C,H,W`), I need to first reshape the input to match the network's/PyTorch's original requirements. That's why I think ONNX cannot really use TensorFlow's high-level convolutional layers API and therefore uses some non-optimized series of operations instead – SomethingSomething Jan 01 '20 at 07:11
  • Sorry, I'm not familiar with onnx. I found two links, [link](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html), [link](https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowImport.ipynb), and I hope they can give you a clue. – zhf061 Jan 01 '20 at 13:49
  • @zhf061, thanks. It appears that the problem is related to the TensorFlow graph not being frozen. I currently cannot figure out how to do that when the model is imported from ONNX – SomethingSomething Jan 08 '20 at 13:34

0 Answers