I am trying to convert a PyTorch model that consists of multiple networks to ONNX, and I have run into a problem.
The git repo: https://github.com/InterDigitalInc/HRFAE
The Trainer class:
class Trainer(nn.Module):
    def __init__(self, config):
        super(Trainer, self).__init__()
        # Load Hyperparameters
        self.config = config
        # Networks
        self.enc = Encoder()
        self.dec = Decoder()
        self.mlp_style = Mod_Net()
        self.dis = Dis_PatchGAN()
        ...
Here is how the trained model processes an image:
def gen_encode(self, x_a, age_a, age_b=0, training=False, target_age=0):
    if target_age:
        self.target_age = target_age
        age_modif = self.target_age*torch.ones(age_a.size()).type_as(age_a)
    else:
        age_modif = self.random_age(age_a, diff_val=25)
    # Generate modified image
    self.content_code_a, skip_1, skip_2 = self.enc(x_a)
    style_params_a = self.mlp_style(age_a)
    style_params_b = self.mlp_style(age_modif)
    x_a_recon = self.dec(self.content_code_a, style_params_a, skip_1, skip_2)
    x_a_modif = self.dec(self.content_code_a, style_params_b, skip_1, skip_2)
    return x_a_recon, x_a_modif, age_modif
And the following is how I tried to convert it to ONNX:
enc = Encoder()
dec = Decoder()
mlp = Mod_Net()
layers = [enc, mlp, dec]
model = torch.nn.Sequential(*layers)
# here is my confusion: how do I specify the inputs of each layer??
# E.g. one of the outputs of 'enc' layer should be input of 'mlp' layer,
# or the outputs of 'enc' layer should be part of inputs of 'dec' layer...
params = torch.load('./logs/001/checkpoint')
model[0].load_state_dict(params['enc_state_dict'])
model[1].load_state_dict(params['mlp_style_state_dict'])
model[2].load_state_dict(params['dec_state_dict'])
torch.onnx.export(model, torch.randn([1, 3, 1024, 1024]), 'trained_hrfae.onnx', do_constant_folding=True)
Maybe my conversion code is written the wrong way? Could anyone help? Many thanks!
#20210629-11:52GMT Edit:
I found that torch.nn.Sequential has a constraint: the output of each layer must match the input expected by the next layer. So my code cannot work at all, because the output of the 'enc' layer does not match the input of the 'mlp' layer.
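To illustrate the constraint, here is a minimal sketch. ToyEncoder and ToyModNet are hypothetical stand-ins, not the real Encoder and Mod_Net from the repo; they only mimic the shape of the problem (three outputs from the first module, a single tensor expected by the second).

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for Encoder: returns three tensors, like self.enc(x_a)."""

    def forward(self, x):
        return x, x * 2, x * 3


class ToyModNet(nn.Module):
    """Hypothetical stand-in for Mod_Net: expects a single tensor (the age)."""

    def forward(self, age):
        return age + 1


# nn.Sequential passes the whole output of one module as the single
# argument of the next one, so ToyModNet receives a tuple, not a tensor.
chained = nn.Sequential(ToyEncoder(), ToyModNet())

try:
    chained(torch.zeros(1, 3))
    failed = False
except TypeError as err:
    failed = True
    print("Sequential failed:", err)
```

The real networks would probably fail with a different error message, but the cause is the same: Sequential has no way to route separate outputs to separate inputs.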
Could anyone explain how to convert this type of PyTorch model to ONNX? Many thanks again :)
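One possible direction, sketched under assumptions: instead of nn.Sequential, wrap the sub-networks in a single nn.Module whose forward wires them together the same way gen_encode does, and export that wrapper with a tuple of inputs. The Toy* classes, InferenceWrapper, export_wrapper, the input names, and all tensor shapes below are hypothetical; the real Encoder, Mod_Net and Decoder have their own layers and may expect a different age shape.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for Encoder: image -> (content, skip_1, skip_2)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, x):
        content = self.conv(x)
        return content, content * 0.5, content * 0.25


class ToyModNet(nn.Module):
    """Hypothetical stand-in for Mod_Net: age -> style parameters."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 4)

    def forward(self, age):
        return self.fc(age)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for Decoder: (content, style, skips) -> image."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 3, kernel_size=3, padding=1)

    def forward(self, content, style, skip_1, skip_2):
        # Scale each channel by its style value, then add the skip features.
        scaled = content * style.view(-1, 4, 1, 1)
        return self.conv(scaled + skip_1 + skip_2)


class InferenceWrapper(nn.Module):
    """Wires the three networks explicitly, mirroring the calls in gen_encode."""

    def __init__(self, enc, mlp_style, dec):
        super().__init__()
        self.enc = enc
        self.mlp_style = mlp_style
        self.dec = dec

    def forward(self, x_a, age):
        content, skip_1, skip_2 = self.enc(x_a)
        style = self.mlp_style(age)
        return self.dec(content, style, skip_1, skip_2)


def export_wrapper(model, path):
    """Export with two dummy inputs; the shapes here are assumptions."""
    model.eval()
    dummy_image = torch.randn(1, 3, 64, 64)
    dummy_age = torch.full((1, 1), 65.0)
    torch.onnx.export(
        model,
        (dummy_image, dummy_age),  # a tuple supplies several forward() args
        path,
        input_names=["image", "age"],
        output_names=["modified_image"],
    )


wrapper = InferenceWrapper(ToyEncoder(), ToyModNet(), ToyDecoder())
output = wrapper(torch.randn(1, 3, 64, 64), torch.full((1, 1), 65.0))
print(output.shape)
```

With the real networks, the idea would be to build Encoder(), Mod_Net() and Decoder(), load 'enc_state_dict', 'mlp_style_state_dict' and 'dec_state_dict' from the checkpoint as in the code above, pass the three modules to the wrapper, and then call export_wrapper(wrapper, 'trained_hrfae.onnx') with dummy inputs of the sizes the real model expects. Whether every operation inside the real networks is supported by the ONNX exporter is not something this sketch can confirm.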