
I have already trained a complete model in PyTorch, but I want to convert the .pth file into a .pb file that can be used in TensorFlow. Does anyone have any ideas?

Rafael

2 Answers


You can use ONNX (Open Neural Network Exchange format).

To convert a .pth file to .pb, first export the model defined in PyTorch to ONNX, then import the ONNX model into TensorFlow (PyTorch => ONNX => TensorFlow).

Here is an example with an MNIST model, taken from the onnx/tutorials tutorial "Convert a PyTorch model to Tensorflow using ONNX":

Save the trained model to a file

import torch
torch.save(model.state_dict(), 'output/mnist.pth')

Load the trained model from file

trained_model = Net()  # Net is the MNIST model class defined in the onnx/tutorials example
trained_model.load_state_dict(torch.load('output/mnist.pth'))
trained_model.eval()  # switch dropout/batch-norm layers to inference mode before exporting

# Export the trained model to ONNX
dummy_input = torch.randn(1, 1, 28, 28)  # one grayscale 28 x 28 picture will be the input to the model
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")

Load the ONNX file

import onnx
from onnx_tf.backend import prepare  # provided by the onnx-tf package

model = onnx.load('output/mnist.onnx')

# Import the ONNX model into TensorFlow
tf_rep = prepare(model)

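Save the TensorFlow model to a file (with recent, TensorFlow 2.x-based versions of onnx-tf, this path is written as a SavedModel directory containing a saved_model.pb rather than a single frozen graph file)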

tf_rep.export_graph('output/mnist.pb')

As noted by @tsveti_iko in the comments:

NOTE: `prepare()` is part of the `onnx-tf` package, so you first need to install it from the console with `pip install onnx-tf`, then import it in your code with `import onnx` and `from onnx_tf.backend import prepare`. After that you can use it as described in the answer.
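To check that the conversion preserved the model's behavior, you can run the same inputs through both models and compare the outputs. Here is a minimal sketch, assuming each model is wrapped in a callable that returns a flat sequence of floats; `outputs_match` is a hypothetical helper, not part of PyTorch, ONNX or onnx-tf, and the tolerances are illustrative defaults:

```python
import math
from typing import Callable, Iterable, Sequence


# Hypothetical helper: not part of PyTorch, ONNX or onnx-tf.
def outputs_match(
    run_reference: Callable[[object], Sequence[float]],
    run_converted: Callable[[object], Sequence[float]],
    inputs: Iterable[object],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-5,
) -> bool:
    """Return True if both models give numerically close outputs on every input.

    Each callable takes one input and returns a flat sequence of floats,
    e.g. the logits for a single MNIST image.
    """
    for x in inputs:
        expected = list(run_reference(x))
        actual = list(run_converted(x))
        # A different number of outputs is a mismatch by itself.
        if len(expected) != len(actual):
            return False
        # Compare element-wise, allowing small floating-point differences.
        for e, a in zip(expected, actual):
            if not math.isclose(e, a, rel_tol=rel_tol, abs_tol=abs_tol):
                return False
    return True
```

For example, assuming `numpy` is imported as `np` and `trained_model`, `dummy_input` and `tf_rep` are the objects from the steps above, `outputs_match(lambda x: trained_model(x).detach().numpy().ravel(), lambda x: np.asarray(tf_rep.run(x.numpy())[0]).ravel(), [dummy_input])` is expected to return True if the conversion is faithful. The exact return type of `tf_rep.run` can vary between onnx-tf versions, so adjust the second callable if needed.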

Dishin H Goyani
  • NOTE: `prepare()` is part of the `onnx-tf` package, so you first need to install it from the console with `pip install onnx-tf`, then import it in your code with `import onnx` and `from onnx_tf.backend import prepare`. After that you can use it as described in the answer. – tsveti_iko Apr 29 '20 at 15:35
  • @tsveti_iko thanks for bringing this to my attention. I thought it belonged in the answer rather than in a comment, so I have added it to the answer. – Dishin H Goyani Apr 30 '20 at 06:15
  • Indeed, you could even put the code parts into your existing code blocks and remove the NOTE quote. – tsveti_iko Apr 30 '20 at 14:47

If you are using TF 1.15 or below, the code above may not work out of the box, because you will likely run into version-mismatch errors.
Here is a set of package versions that work together with TF 1.x:

Keras                2.3.0
Keras-Applications   1.0.8
Keras-Preprocessing  1.1.2
numpy                1.21.5
onnx                 1.8.0
onnx-tf              1.3.0
protobuf             3.19.4
tensorboard          1.15.0
tensorflow           1.15.0
tensorflow-estimator 1.15.1
torch                1.6.0+cpu
torchvision          0.7.0+cpu
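To confirm that an environment actually matches these versions before running the conversion, you can compare the installed distributions against the pins. A minimal sketch: `PINS` and `find_mismatches` are illustrative names, the versions are copied from the list above, and it uses `importlib.metadata` (Python 3.8+) with a fallback to the `importlib-metadata` backport for older interpreters such as 3.7. Lookups of dashed names such as `onnx-tf` may behave slightly differently across `importlib` versions, so treat it as a starting point:

```python
from typing import Callable, Dict, List, Optional, Tuple

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # older Python, e.g. 3.7: pip install importlib-metadata
    from importlib_metadata import PackageNotFoundError, version

# Illustrative pins, copied from the version list above.
PINS: Dict[str, str] = {
    "Keras": "2.3.0",
    "Keras-Applications": "1.0.8",
    "Keras-Preprocessing": "1.1.2",
    "numpy": "1.21.5",
    "onnx": "1.8.0",
    "onnx-tf": "1.3.0",
    "protobuf": "3.19.4",
    "tensorboard": "1.15.0",
    "tensorflow": "1.15.0",
    "tensorflow-estimator": "1.15.1",
    "torch": "1.6.0+cpu",
    "torchvision": "0.7.0+cpu",
}


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = version,
) -> List[Tuple[str, str, Optional[str]]]:
    """Return (package, wanted, installed) for every pin that is not satisfied.

    `installed` is None when the package is not installed at all.
    """
    mismatches = []
    for name, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches.append((name, wanted, installed))
    return mismatches
```

Running `print(find_mismatches(PINS))` in the target environment lists every package whose installed version differs from its pin; an empty list means the environment matches.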

With these packages installed, use the code from Dishin's answer.

Note: `Variable` is deprecated in newer versions of PyTorch; pass a plain tensor such as `torch.randn(1, 1, 28, 28)` as the dummy input instead.

Prajot Kuvalekar