
I have exported my PyTorch model to ONNX. Now, is there a way for me to obtain the input layer from that ONNX model?

Exporting PyTorch model to ONNX

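import torch.onnx  # also binds the top-level torch name

# `model` is an instance of the original nn.Module class, constructed beforehand
checkpoint = torch.load("./saved_pytorch_model.pth")
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # export in inference mode (dropout/batchnorm behave deterministically)
# example input used to trace the graph; df_X is the pandas DataFrame of features
dummy_input = torch.tensor(df_X.values).float()  # avoid shadowing the builtin input()
torch.onnx.export(model, dummy_input, "onnx_model.onnx")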

Loading ONNX model

import onnx

onnx_model = onnx.load('onnx_model.onnx')

I want to be able to somehow obtain the input layer from onnx_model. Is this possible?

Shawn Zhang

2 Answers


The ONNX model is a protobuf structure, as defined here (https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto). You can work with it using the standard protobuf methods generated for Python (see: https://developers.google.com/protocol-buffers/docs/reference/python-generated). I don't understand exactly what you want to extract, but you can iterate through the nodes that make up the graph (model.graph.node). The first node in the graph may or may not correspond to what you would consider the first layer; that depends on how the translation was done. You can also get the inputs of the model (model.graph.input).

G. Ramalingam

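The onnx library lets you extract the names and shapes of all the inputs as follows:

import onnx

model = onnx.load("onnx_model.onnx")  # path to the exported file
inputs = {}
for inp in model.graph.input:
    # each dimension holds either a fixed size (dim_value) or a symbolic name (dim_param)
    dims = inp.type.tensor_type.shape.dim
    inputs[inp.name] = [d.dim_value if d.HasField("dim_value") else d.dim_param for d in dims]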
AcidBurn