I saved an nn.Module model using (roughly) the following:

model = MyWeirdModel()
model.patched_features = .....
train(model)
torch.save(model, file)

Ideally, one would load this model using

model = torch.load(file)

However, this doesn't work in my case because pickle uses the static class definition when unpickling, so I get AttributeError: 'MyWeirdModel' object has no attribute 'patched_features' (this attribute was added to MyWeirdModel at runtime).
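
To illustrate the failure mode, here is a minimal sketch that uses plain pickle rather than torch.save. It assumes the patch was applied to the class itself (the real patch is not shown above). The class body and the lambda are hypothetical stand-ins, not the actual MyWeirdModel:

```python
import pickle


class MyWeirdModel:
    """Hypothetical stand-in for the real model class."""

    def __init__(self):
        self.weight = 1.0  # instance state, stored inside the pickle


# Assumption: the attribute is patched onto the class at runtime.
MyWeirdModel.patched_features = lambda self: self.weight * 2

model = MyWeirdModel()
payload = pickle.dumps(model)  # stores the class by name plus the instance __dict__

# Simulate a fresh process in which the class is defined without the patch.
del MyWeirdModel.patched_features

restored = pickle.loads(payload)
print(restored.weight)  # instance attributes survive the round trip

try:
    restored.patched_features
    error_message = ""
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

pickle records only a reference to the class (module and name) together with the instance's own attributes, so anything attached to the class at runtime has to be re-applied before the loaded object is used.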

I would like to avoid retraining the model, so I can only change the loading code, not the saving code. Ideally, I would do something like this:

# Initialise the model in the same way as before
model = MyWeirdModel()
model.patched_features = .....

state_dict = load_state_dict_only(file) # How does one do this?

model.load_state_dict(state_dict)

My understanding is that torch.save() saves the model AND the state dict. How do I load only the state dict from the pickled model, such that I can recover the model?

mxbi
  • Hey, did you figure out how to solve it? Have the exact same problem right now – Dion Jul 19 '22 at 09:25
  • Sadly, I never did, @Dion. I may place a bounty on this. – mxbi Sep 14 '22 at 22:24
  • what exactly is the `patched_features`? – KonstantinosKokos Sep 16 '22 at 20:44
  • Linking https://stackoverflow.com/questions/50465106/attributeerror-when-reading-a-pickle-file which lists several possible solutions for deserializing a Python object when the `__main__` module has changed, most notably using a custom deserializer. This seems to be the main path forward given that (a) you know not to save object/function instances in pickles, but (b) you have already done so and need a one-time solution to retrieve this data – DerekG Sep 19 '22 at 01:27

1 Answer

According to the documentation,

only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s state_dict.

Your patched_features attribute is not stored in the model’s state_dict. You can make it part of the state_dict by registering it with register_parameter, as pointed out by ptrblck in this post, but doing so may involve retraining.
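
As a rough illustration of the difference, the sketch below compares a plain tensor attribute with a registered parameter and a registered buffer. TinyModel and all attribute names are hypothetical, and it assumes patched_features is an ordinary tensor, which the question does not confirm:

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for MyWeirdModel."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)


model = TinyModel()

# A plain tensor attribute is not tracked by nn.Module.
model.plain_features = torch.zeros(3)

# Registered parameters and buffers are tracked and end up in the state_dict.
model.register_parameter("patched_param", nn.Parameter(torch.zeros(3)))
model.register_buffer("patched_buffer", torch.zeros(3))

keys = list(model.state_dict().keys())
print(keys)
```

Only the registered names appear among the keys, alongside the entries for the linear layer; plain_features is absent.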

Angus