
I am trying to convert a PyTorch model that is composed of multiple networks to ONNX, and have run into some problems.

The git repo: https://github.com/InterDigitalInc/HRFAE

The Trainer Class:

class Trainer(nn.Module):
    def __init__(self, config):
        super(Trainer, self).__init__()
        # Load Hyperparameters
        self.config = config
        # Networks
        self.enc = Encoder()
        self.dec = Decoder()
        self.mlp_style = Mod_Net()
        self.dis = Dis_PatchGAN()
        ...

Here is how the trained model process image:

    def gen_encode(self, x_a, age_a, age_b=0, training=False, target_age=0):
        if target_age:
            self.target_age = target_age
            age_modif = self.target_age*torch.ones(age_a.size()).type_as(age_a)
        else:
            age_modif = self.random_age(age_a, diff_val=25)

        # Generate modified image
        self.content_code_a, skip_1, skip_2 = self.enc(x_a)
        style_params_a = self.mlp_style(age_a)
        style_params_b = self.mlp_style(age_modif)
        
        x_a_recon = self.dec(self.content_code_a, style_params_a, skip_1, skip_2)
        x_a_modif = self.dec(self.content_code_a, style_params_b, skip_1, skip_2)
        
        return x_a_recon, x_a_modif, age_modif

And here is how I tried to convert it to ONNX:

enc = Encoder()
dec = Decoder()
mlp = Mod_Net()
layers = [enc, mlp, dec]
model = torch.nn.Sequential(*layers)
# Here is my confusion: how do I specify the inputs of each layer?
# E.g. one of the outputs of the 'enc' layer should be an input of the 'mlp' layer,
# and the outputs of the 'enc' layer should be part of the inputs of the 'dec' layer...

params = torch.load('./logs/001/checkpoint')  
model[0].load_state_dict(params['enc_state_dict'])
model[1].load_state_dict(params['mlp_style_state_dict'])
model[2].load_state_dict(params['dec_state_dict'])

torch.onnx.export(model, torch.randn([1, 3, 1024, 1024]), 'trained_hrfae.onnx', do_constant_folding=True)  

Maybe the conversion code is written the wrong way? Could anyone help? Many thanks!

#20210629-11:52GMT Edit:

I found that there is a constraint on using torch.nn.Sequential: the output of each layer in a Sequential must be compatible with the input of the next layer. So my code shouldn't work at all, because the output of the 'enc' layer is not compatible with the input of the 'mlp' layer.
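That constraint can be seen with a minimal, self-contained sketch (the two stand-in modules below are illustrative, not the repo's real Encoder/Mod_Net): nn.Sequential hands the *entire* return value of one module to the next, so a module returning a tuple breaks a successor that expects a single tensor.

```python
import torch
import torch.nn as nn

class TwoOutputs(nn.Module):
    # Returns a tuple, like Encoder's (content_code, skip_1, skip_2)
    def forward(self, x):
        return x, x * 2

class OneInput(nn.Module):
    # Expects a single tensor, like Mod_Net expects an age tensor
    def forward(self, x):
        return x + 1

seq = nn.Sequential(TwoOutputs(), OneInput())

failed = False
try:
    seq(torch.randn(2, 3))
except TypeError as e:
    failed = True
    print("nn.Sequential cannot route multiple outputs:", e)
```

The tuple from the first module is passed as-is to the second, which then fails on `tuple + 1`.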

Could anyone help with how to convert this kind of PyTorch model to ONNX? Many thanks again :)

ZZ Shao
  • Please do not bury the error message in the code itself; post the complete error trace - see how to create a [mre]. – desertnaut Jun 29 '21 at 11:34
  • @desertnaut thanks for the tip, but I think the main problem is the code itself; the error messages are just in comments. I'll keep it in mind for the future. – ZZ Shao Jun 29 '21 at 11:40
  • I am not following; of course the problem is in the code (where else could it be?). You may want to keep it in mind just for the future, but also please keep in mind that, as is, your question is eligible for closure for lack of details and/or an MRE. – desertnaut Jun 29 '21 at 11:45
  • OK, I got it. I mean, the error message was used to help explain the situation, so I put it in comments. Right now I've found another issue in the code, so I'm going to edit the question. At the same time I'll remove the errors from the code as you suggested. Thanks. – ZZ Shao Jun 29 '21 at 11:51

1 Answer


After some research and experimentation, I found a method that may be the right way:

Convert each network (Encoder, Mod_Net, Decoder) to its own ONNX model, and handle their inputs/outputs in later glue logic or in any further procedure (e.g. conversion to a TFLite model).
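With this approach, the routing that gen_encode performs has to be reproduced at inference time by hand. A minimal sketch of that stitching, using tiny stand-in modules (the real Encoder/Mod_Net/Decoder come from the repo, and the shapes here are illustrative only); at runtime each call would instead be a separate ONNX inference session:

```python
import torch
import torch.nn as nn

# Stand-ins for the three separately exported graphs.
class TinyEnc(nn.Module):
    def forward(self, x):
        return x, x[:, :1], x[:, :2]       # content code + two skip tensors

class TinyStyle(nn.Module):
    def forward(self, age):
        return age.float().view(-1, 1, 1, 1)

class TinyDec(nn.Module):
    def forward(self, content, style, skip_1, skip_2):
        return content * style

enc, mlp, dec = TinyEnc(), TinyStyle(), TinyDec()

x, age = torch.randn(1, 3, 8, 8), torch.tensor([30])
content, skip_1, skip_2 = enc(x)            # run enc.onnx
style = mlp(age)                            # run mlp.onnx
out = dec(content, style, skip_1, skip_2)   # run dec.onnx
print(out.shape)  # torch.Size([1, 3, 8, 8])
```

The downside is that this glue code must be re-implemented on every target platform (e.g. in the Android app), which is what makes the single combined network below more attractive.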

I'm trying to port the model to Android using this method.

#Edit 20210705-03:52GMT#

Another approach may be better: write a new network that combines the three networks. I have verified that its output is the same as the original PyTorch model's.

class HRFAE(nn.Module):
    def __init__(self):
        super(HRFAE, self).__init__()
        self.enc = Encoder()
        self.mlp_style = Mod_Net()
        self.dec = Decoder()

    def forward(self, x, age_modif):
        content_code_a, skip_1, skip_2 = self.enc(x)
        style_params_b = self.mlp_style(age_modif)

        x_a_modif = self.dec(content_code_a, style_params_b, skip_1, skip_2)

        return x_a_modif

and then convert it using the following:

net = HRFAE()

params = torch.load('./logs/002/checkpoint')
net.enc.load_state_dict(params['enc_state_dict'])
net.mlp_style.load_state_dict(params['mlp_style_state_dict'])
net.dec.load_state_dict(params['dec_state_dict'])

net.eval()
torch.onnx.export(net, (torch.randn([1, 3, 512, 512]), torch.randn([1]).type(torch.long)), 'test_hrfae.onnx')

This should be the answer.

ZZ Shao