
I'm running FastGAN (https://github.com/odegeasslbc/FastGAN-pytorch) on Google Colab and am now trying to resume training from a saved .pth checkpoint that the training script generated. However, it keeps throwing this error:

Traceback (most recent call last):
  File "train.py", line 202, in 
    train(args)
  File "train.py", line 117, in train
    netG.load_state_dict(ckpt['g'])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict: "init.init.0.weight_orig", "init.init.0.weight", "init.init.0.weight_u", "init.init.0.weight_orig", "init.init.0.weight_u", "init.init.0.weight_v", "init.init.1.weight", "init.init.1.bias", "init.init.1.running_mean", "init.init.1.running_var", "feat_8.1.weight_orig", "feat_8.1.weight", "feat_8.1.weight_u", "feat_8.1.weight_orig", "feat_8.1.weight_u", "feat_8.1.weight_v", "feat_8.2.weight", "feat_8.3.weight", "feat_8.3.bias", "feat_8.3.running_mean", "feat_8.3.running_var", "feat_8.5.weight_orig", "feat_8.5.weight", "feat_8.5.weight_u", "feat_8.5.weight_orig", "feat_8.5.weight_u", "feat_8.5.weight_v", "feat_8.6.weight", "feat_8.7.weight", "feat_8.7.bias", "feat_8.7.running_mean", "feat_8.7.running_var", "feat_16.1.weight_orig", "feat_16.1.weight", "feat_16.1.weight_u", "feat_16.1.weight_orig", "feat_16.1.weight_u", "feat_16.1.weight_v", "feat_16.2.weight", "feat_16.2.bias", "feat_16.2.running_mean", "feat_16.2.running_var", "feat_32.1.weight_orig", "feat_32.1.weight", "feat_32.1.weight_u", "feat_32.1.weight_orig", "feat_32.1.weight_u", "feat_32.1.weight_v", "feat_32.2.weight", "feat_32.3.weight", "feat_32.3.bias", "feat_32.3.running_mean", "feat_32.3.running_var", "feat_32.5.weight_orig", "feat_32.5.weight", "feat_32.5.weight_u", "feat_32.5.weight_orig", "feat_32.5.weight_u", "feat_32.5.weight_v", "feat_32.6.weight", "feat_32.7.weight", "feat_32.7.bias", "feat_32.7.running_mean", "feat_32.7.running_var", "feat_64.1.weight_orig", "feat_64.1.weight", "feat_64.1.weight_u", "feat_64.1.weight_orig", "feat_64.1.weight_u", "feat_64.1.weight_v", "feat_64.2.weight", "feat_64.2.bias", "feat_64.2.running_mean", "feat_64.2.running_var", "feat_128.1.weight_orig", "feat_128.1.weight", "feat_128.1.weight_u", "feat_128.1.weight_orig", "feat_128.1.weight_u", "feat_128.1.weight_v", "feat_128.2.weight", "feat_128.3.weight", "feat_128.3.bias", "feat_128.3.running_mean", "feat_128.3.running_var", "feat_128.5.weight_orig", "feat_128.5.weight", "feat_128.5.weight_u", "feat_128.5.weight_orig", "feat_128.5.weight_u", "feat_128.5.weight_v", "feat_128.6.weight", "feat_128.7.weight", "feat_128.7.bias", "feat_128.7.running_mean", "feat_128.7.running_var", "feat_256.1.weight_orig", "feat_256.1.weight", "feat_256.1.weight_u", "feat_256.1.weight_orig", "feat_256.1.weight_u", "feat_256.1.weight_v", "feat_256.2.weight", "feat_256.2.bias", "feat_256.2.running_mean", "feat_256.2.running_var", "se_64.main.1.weight_orig", "se_64.main.1.weight", "se_64.main.1.weight_u", "se_64.main.1.weight_orig", "se_64.main.1.weight_u", "se_64.main.1.weight_v", "se_64.main.3.weight_orig", "se_64.main.3.weight", "se_64.main.3.weight_u", "se_64.main.3.weight_orig", "se_64.main.3.weight_u", "se_64.main.3.weight_v", "se_128.main.1.weight_orig", "se_128.main.1.weight", "se_128.main.1.weight_u", "se_128.main.1.weight_orig", "se_128.main.1.weight_u", "se_128.main.1.weight_v", "se_128.main.3.weight_orig", "se_128.main.3.weight", "se_128.main.3.weight_u", "se_128.main.3.weight_orig", "se_128.main.3.weight_u", "se_128.main.3.weight_v", "se_256.main.1.weight_orig", "se_256.main.1.weight", "se_256.main.1.weight_u", "se_256.main.1.weight_orig", "se_256.main.1.weight_u", "se_256.main.1.weight_v", "se_256.main.3.weight_orig", "se_256.main.3.weight", "se_256.main.3.weight_u", "se_256.main.3.weight_orig", "se_256.main.3.weight_u", "se_256.main.3.weight_v", "to_128.weight_orig", "to_128.weight", "to_128.weight_u", "to_128.weight_orig", "to_128.weight_u", "to_128.weight_v", "to_big.weight_orig", "to_big.weight", 
"to_big.weight_u", "to_big.weight_orig", "to_big.weight_u", "to_big.weight_v", "feat_512.1.weight_orig", "feat_512.1.weight", "feat_512.1.weight_u", "feat_512.1.weight_orig", "feat_512.1.weight_u", "feat_512.1.weight_v", "feat_512.2.weight", "feat_512.3.weight", "feat_512.3.bias", "feat_512.3.running_mean", "feat_512.3.running_var", "feat_512.5.weight_orig", "feat_512.5.weight", "feat_512.5.weight_u", "feat_512.5.weight_orig", "feat_512.5.weight_u", "feat_512.5.weight_v", "feat_512.6.weight", "feat_512.7.weight", "feat_512.7.bias", "feat_512.7.running_mean", "feat_512.7.running_var", "se_512.main.1.weight_orig", "se_512.main.1.weight", "se_512.main.1.weight_u", "se_512.main.1.weight_orig", "se_512.main.1.weight_u", "se_512.main.1.weight_v", "se_512.main.3.weight_orig", "se_512.main.3.weight", "se_512.main.3.weight_u", "se_512.main.3.weight_orig", "se_512.main.3.weight_u", "se_512.main.3.weight_v", "feat_1024.1.weight_orig", "feat_1024.1.weight", "feat_1024.1.weight_u", "feat_1024.1.weight_orig", "feat_1024.1.weight_u", "feat_1024.1.weight_v", "feat_1024.2.weight", "feat_1024.2.bias", "feat_1024.2.running_mean", "feat_1024.2.running_var". 
    Unexpected key(s) in state_dict: "module.init.init.0.weight_orig", "module.init.init.0.weight_u", "module.init.init.0.weight_v", "module.init.init.1.weight", "module.init.init.1.bias", "module.init.init.1.running_mean", "module.init.init.1.running_var", "module.init.init.1.num_batches_tracked", "module.feat_8.1.weight_orig", "module.feat_8.1.weight_u", "module.feat_8.1.weight_v", "module.feat_8.2.weight", "module.feat_8.3.weight", "module.feat_8.3.bias", "module.feat_8.3.running_mean", "module.feat_8.3.running_var", "module.feat_8.3.num_batches_tracked", "module.feat_8.5.weight_orig", "module.feat_8.5.weight_u", "module.feat_8.5.weight_v", "module.feat_8.6.weight", "module.feat_8.7.weight", "module.feat_8.7.bias", "module.feat_8.7.running_mean", "module.feat_8.7.running_var", "module.feat_8.7.num_batches_tracked", "module.feat_16.1.weight_orig", "module.feat_16.1.weight_u", "module.feat_16.1.weight_v", "module.feat_16.2.weight", "module.feat_16.2.bias", "module.feat_16.2.running_mean", "module.feat_16.2.running_var", "module.feat_16.2.num_batches_tracked", "module.feat_32.1.weight_orig", "module.feat_32.1.weight_u", "module.feat_32.1.weight_v", "module.feat_32.2.weight", "module.feat_32.3.weight", "module.feat_32.3.bias", "module.feat_32.3.running_mean", "module.feat_32.3.running_var", "module.feat_32.3.num_batches_tracked", "module.feat_32.5.weight_orig", "module.feat_32.5.weight_u", "module.feat_32.5.weight_v", "module.feat_32.6.weight", "module.feat_32.7.weight", "module.feat_32.7.bias", "module.feat_32.7.running_mean", "module.feat_32.7.running_var", "module.feat_32.7.num_batches_tracked", "module.feat_64.1.weight_orig", "module.feat_64.1.weight_u", "module.feat_64.1.weight_v", "module.feat_64.2.weight", "module.feat_64.2.bias", "module.feat_64.2.running_mean", "module.feat_64.2.running_var", "module.feat_64.2.num_batches_tracked", "module.feat_128.1.weight_orig", "module.feat_128.1.weight_u", "module.feat_128.1.weight_v", "module.feat_128.2.weight", "module.feat_128.3.weight", "module.feat_128.3.bias", "module.feat_128.3.running_mean", "module.feat_128.3.running_var", "module.feat_128.3.num_batches_tracked", "module.feat_128.5.weight_orig", "module.feat_128.5.weight_u", "module.feat_128.5.weight_v", "module.feat_128.6.weight", "module.feat_128.7.weight", "module.feat_128.7.bias", "module.feat_128.7.running_mean", "module.feat_128.7.running_var", "module.feat_128.7.num_batches_tracked", "module.feat_256.1.weight_orig", "module.feat_256.1.weight_u", "module.feat_256.1.weight_v", "module.feat_256.2.weight", "module.feat_256.2.bias", "module.feat_256.2.running_mean", "module.feat_256.2.running_var", "module.feat_256.2.num_batches_tracked", "module.se_64.main.1.weight_orig", "module.se_64.main.1.weight_u", "module.se_64.main.1.weight_v", "module.se_64.main.3.weight_orig", "module.se_64.main.3.weight_u", "module.se_64.main.3.weight_v", "module.se_128.main.1.weight_orig", "module.se_128.main.1.weight_u", "module.se_128.main.1.weight_v", "module.se_128.main.3.weight_orig", "module.se_128.main.3.weight_u", "module.se_128.main.3.weight_v", "module.se_256.main.1.weight_orig", "module.se_256.main.1.weight_u", "module.se_256.main.1.weight_v", "module.se_256.main.3.weight_orig", "module.se_256.main.3.weight_u", "module.se_256.main.3.weight_v", "module.to_128.weight_orig", "module.to_128.weight_u", "module.to_128.weight_v", "module.to_big.weight_orig", "module.to_big.weight_u", "module.to_big.weight_v", "module.feat_512.1.weight_orig", "module.feat_512.1.weight_u", "module.feat_512.1.weight_v", 
"module.feat_512.2.weight", "module.feat_512.3.weight", "module.feat_512.3.bias", "module.feat_512.3.running_mean", "module.feat_512.3.running_var", "module.feat_512.3.num_batches_tracked", "module.feat_512.5.weight_orig", "module.feat_512.5.weight_u", "module.feat_512.5.weight_v", "module.feat_512.6.weight", "module.feat_512.7.weight", "module.feat_512.7.bias", "module.feat_512.7.running_mean", "module.feat_512.7.running_var", "module.feat_512.7.num_batches_tracked", "module.se_512.main.1.weight_orig", "module.se_512.main.1.weight_u", "module.se_512.main.1.weight_v", "module.se_512.main.3.weight_orig", "module.se_512.main.3.weight_u", "module.se_512.main.3.weight_v", "module.feat_1024.1.weight_orig", "module.feat_1024.1.weight_u", "module.feat_1024.1.weight_v", "module.feat_1024.2.weight", "module.feat_1024.2.bias", "module.feat_1024.2.running_mean", "module.feat_1024.2.running_var", "module.feat_1024.2.num_batches_tracked". 

Any idea what might be happening here?

Thanks so much for any help!

KT9000
  • I recommend you look up the issues on that GitHub repo or raise this issue there, as the repo owner is more likely to come up with a solution than the folks on StackOverflow – The Singularity Oct 07 '21 at 09:17
  • I've read up on the issues already but was hesitant to bother the repo owner. I'm not familiar with Python/PyTorch, so I tend to assume it's just a simple mistake on my end. Might try that though, thanks! – KT9000 Oct 07 '21 at 09:25
  • My apologies, but I doubt you'll receive an answer for this here – The Singularity Oct 07 '21 at 09:27

1 Answer


This commonly happens when the attribute names of the submodules in your nn.Module no longer match the keys stored in the saved state dict.

Notice how your model's keys differ from the ones contained in the loaded state dict only by their prefix: every key in the checkpoint starts with 'module.'.

A quick fix would be to strip away this prefix. You could, for instance, use a dict comprehension:

# strip the 'module.' prefix from every key; checking startswith avoids
# touching any accidental interior occurrence of the substring
loaded_state = {k[len('module.'):] if k.startswith('module.') else k: v
                for k, v in ckpt['g'].items()}
netG.load_state_dict(loaded_state)
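
If the checkpoint was saved from a model wrapped in nn.DataParallel (which prepends 'module.' to every key of the state dict), another option is to wrap your freshly built model the same way before loading. A minimal sketch, assuming netG is the unwrapped Generator and ckpt is the loaded checkpoint:

import torch.nn as nn

# Wrapping in nn.DataParallel reintroduces the 'module.' prefix,
# so the checkpoint keys line up with the model's keys without renaming.
netG = nn.DataParallel(netG)
netG.load_state_dict(ckpt['g'])
# The underlying Generator remains reachable via netG.module.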
Ivan
  • Thanks for the fix! Just tried it but it seems like the same has to be done for the discriminator keys. Is this the best method to fix the issue? Could it have something to do with the model being saved using nn.DataParallel? I see that multi_gpu is set to True by default. – KT9000 Oct 07 '21 at 11:27
  • You are right, `nn.DataParallel` will effectively wrap your `nn.Module`, which results in `module.` being prepended to the keys of your original module's state dict. – Ivan Oct 07 '21 at 11:50
  • So I assume your fix would be the easiest way to revert that? I've tried applying it to both g and d keys but it moves on to another error (just a wrong order?): `File "train.py", line 127, in train optimizerG.load_state_dict(ckpt['opt_g']) UnboundLocalError: local variable 'optimizerG' referenced before assignment` – KT9000 Oct 07 '21 at 12:04
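
The UnboundLocalError in the last comment means optimizerG is only assigned later in train() than the line that tries to restore it, so creating the optimizers before loading the checkpoint should resolve it. A minimal sketch, assuming the discriminator optimizer state is stored under 'opt_d' and using illustrative Adam hyperparameters rather than FastGAN's exact settings:

import torch.optim as optim

# Build the optimizers first, then restore their saved state.
optimizerG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizerG.load_state_dict(ckpt['opt_g'])
optimizerD.load_state_dict(ckpt['opt_d'])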