I saved an `nn.Module` model using (logically):

```python
model = MyWeirdModel()
model.patched_features = .....
train(model)
torch.save(model, file)
```
Ideally, one would load this model using:

```python
model = torch.load(file)
```
However, in my case this doesn't work because pickle uses the static class definition when unpickling, so I get `AttributeError: 'MyWeirdModel' object has no attribute 'patched_features'` (this attribute was added to `MyWeirdModel` at runtime).
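The behaviour described above can be reproduced with plain `pickle`, which `torch.save` uses by default. This is a minimal sketch that assumes the patch was applied to the class rather than stored on the instance; `MyWeirdModel` here is a toy stand-in and `patch` is a hypothetical helper, not the real model code. The pickle stores the instance's attributes plus only a *reference* to the class, so the object is rebuilt from whatever class definition exists at load time.

```python
import pickle


class MyWeirdModel:
    """Toy stand-in for the real nn.Module subclass (hypothetical)."""

    def __init__(self):
        self.weights = [1.0, 2.0]  # instance data: stored inside the pickle


def patch(cls):
    # Hypothetical runtime patch, standing in for the elided `patched_features = .....`
    cls.patched_features = lambda self: "patched"


patch(MyWeirdModel)
blob = pickle.dumps(MyWeirdModel())  # saves the instance __dict__ plus a reference to MyWeirdModel

# Simulate loading in a fresh process, where the patch has not been applied yet.
del MyWeirdModel.patched_features
restored = pickle.loads(blob)
print(restored.weights)                       # [1.0, 2.0]
print(hasattr(restored, "patched_features"))  # False: the class-level patch was never saved

# Re-applying the patch to the class makes the attribute visible again.
patch(MyWeirdModel)
print(restored.patched_features())            # patched
```

Under this assumption, re-running the same patching code on the class before calling `torch.load` may be enough to load the saved file.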
I would like to avoid retraining the model, so I want to change only the loading code, not the saving code. Something like this:
```python
# Initialise the model in the same way as before
model = MyWeirdModel()
model.patched_features = .....
state_dict = load_state_dict_only(file)  # How does one do this?
model.load_state_dict(state_dict)
```
My understanding is that `torch.save()` saves the model AND the state dict. How do I load only the state dict from the pickled model, so that I can recover the model?
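One possible direction, sketched with the standard library only: a custom unpickler whose `find_class` redirects the saved class to an empty placeholder, so the stored attributes can be recovered without relying on the current class definition. `PermissiveUnpickler`, `RawObject` and `load_state_only` are hypothetical names, and `MyWeirdModel` is again a toy. This is not a drop-in replacement for `torch.load`: a real checkpoint stores tensors separately from the pickle stream, so the unpickler would have to be routed through `torch.load` (presumably via its `pickle_module` argument; treat that as an assumption to verify for your PyTorch version), and the recovered dict would hold `nn.Module` internals such as `_parameters` and `_modules` rather than a ready-made state dict.

```python
import io
import pickle


class MyWeirdModel:
    """Toy stand-in for the saved model class (hypothetical)."""

    def __init__(self):
        self.weights = [1.0, 2.0]
        self.bias = 0.5


class RawObject:
    """Hypothetical placeholder: an empty class that just receives the saved attributes."""


class PermissiveUnpickler(pickle.Unpickler):
    """Hypothetical unpickler that swaps selected classes for RawObject."""

    def __init__(self, file, redirect):
        super().__init__(file)
        self.redirect = redirect  # set of (module, qualified name) pairs to replace

    def find_class(self, module, name):
        # pickle calls find_class for every class reference stored in the data
        if (module, name) in self.redirect:
            return RawObject
        return super().find_class(module, name)


def load_state_only(data, redirect):
    """Hypothetical helper: unpickle `data` and return the instance attributes as a dict."""
    obj = PermissiveUnpickler(io.BytesIO(data), redirect).load()
    return dict(vars(obj))


blob = pickle.dumps(MyWeirdModel())
target = {(MyWeirdModel.__module__, MyWeirdModel.__qualname__)}
state = load_state_only(blob, target)
print(state)  # {'weights': [1.0, 2.0], 'bias': 0.5}
```

Because `RawObject` defines no `__setstate__`, pickle simply copies the saved attribute dict onto the placeholder instance, and no code from the original class runs during loading.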