I have already trained the complete model using PyTorch, but I want to convert the .pth file into a .pb file that can be used in TensorFlow. Does anyone have any ideas?
2 Answers
You can use ONNX (Open Neural Network Exchange format) to convert a .pth file to .pb.
First, export the model defined in PyTorch to ONNX, then import the ONNX model into TensorFlow (PyTorch => ONNX => TensorFlow).
Here is an example, taken from onnx/tutorials, that converts a PyTorch MNIST model to TensorFlow using ONNX:
Save the trained model to a file
torch.save(model.state_dict(), 'output/mnist.pth')
Load the trained model from file
trained_model = Net()
trained_model.load_state_dict(torch.load('output/mnist.pth'))
# Export the trained model to ONNX
dummy_input = torch.randn(1, 1, 28, 28)  # one black-and-white 28 x 28 image as the model input
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")
Load the ONNX file
model = onnx.load('output/mnist.onnx')
# Import the ONNX model into TensorFlow
tf_rep = prepare(model)
Save the TensorFlow model to a file
tf_rep.export_graph('output/mnist.pb')
As noted by @tsveti_iko in the comments:
NOTE: `prepare()` is provided by the `onnx-tf` package, so you first need to install it from the console with `pip install onnx-tf`, then import it in your code with `import onnx` and `from onnx_tf.backend import prepare`. After that you can use it as described above.

NOTE: `prepare()` is built into `onnx-tf`, so you first need to install it from the console with `pip install onnx-tf`, then import it in your code with `import onnx` and `from onnx_tf.backend import prepare`; after that you can use it as described in the answer. – tsveti_iko Apr 29 '20 at 15:35
@tsveti_iko thanks for bringing this to my attention. I thought it belonged in the answer rather than in a comment, so I have added it to the answer. – Dishin H Goyani Apr 30 '20 at 06:15
Indeed, you could even merge those code lines into your existing code blocks and drop the NOTE quote. – tsveti_iko Apr 30 '20 at 14:47
If you are using TF 1.15 or below, the code above may not work out of the box, because you will end up fighting version-mismatch errors.
Here is a set of package versions that work together for TF 1.x:
Keras 2.3.0
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.2
numpy 1.21.5
onnx 1.8.0
onnx-tf 1.3.0
protobuf 3.19.4
tensorboard 1.15.0
tensorflow 1.15.0
tensorflow-estimator 1.15.1
torch 1.6.0+cpu
torchvision 0.7.0+cpu
With all of these packages installed, use the code from Dishin's answer above.
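If you want to double-check an environment against these pins, a small stdlib-only sketch like the following can help. The helper names are hypothetical. `importlib.metadata` needs Python 3.8+, and since TF 1.15 wheels only target Python 3.7 and older, the code falls back to the `importlib_metadata` backport (`pip install importlib-metadata`).

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7, as used with TF 1.15
    import importlib_metadata as metadata

# Pins copied from the list above (TF 1.x toolchain).
PINNED = {
    "Keras": "2.3.0",
    "Keras-Applications": "1.0.8",
    "Keras-Preprocessing": "1.1.2",
    "numpy": "1.21.5",
    "onnx": "1.8.0",
    "onnx-tf": "1.3.0",
    "protobuf": "3.19.4",
    "tensorboard": "1.15.0",
    "tensorflow": "1.15.0",
    "tensorflow-estimator": "1.15.1",
    "torch": "1.6.0+cpu",
    "torchvision": "0.7.0+cpu",
}


def installed_version(name):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs from its pin to (pinned, found)."""
    mismatches = {}
    for name, want in pinned.items():
        found = lookup(name)
        if found != want:
            mismatches[name] = (want, found)
    return mismatches


def requirements_txt(pinned):
    """Render the pins as requirements.txt lines."""
    return "\n".join(f"{name}=={version}" for name, version in pinned.items())
```

`print(find_mismatches(PINNED))` lists anything that is off, and `requirements_txt(PINNED)` gives pins you can save for `pip install -r`. The `+cpu` torch and torchvision builds are not on PyPI; they come from PyTorch's own wheel index.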
Note: `Variable` is deprecated in newer versions of torch; pass a plain tensor such as `torch.randn(1, 1, 28, 28)` instead.
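To illustrate the note: since PyTorch 0.4, wrapping a tensor in `Variable` still works but simply returns an ordinary `Tensor`, so the dummy input for `torch.onnx.export` can be created directly. A minimal sketch:

```python
import torch
from torch.autograd import Variable

# Deprecated style: still works, but Variable(...) now returns a plain Tensor.
legacy_input = Variable(torch.randn(1, 1, 28, 28))

# Current style: use the tensor directly (set requires_grad=True only if you need gradients).
dummy_input = torch.randn(1, 1, 28, 28)
```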
