When loading a model whose weights were saved in distributed mode, the parameter (key) names in the state_dict don't match, resulting in this error. How can I resolve this?

  File "/code/src/bert_structure_prediction/model.py", line 36, in __init__
    self.load_state_dict(state_dict)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BertCoordinatePredictor:
        Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddings.weight", ...etc.
Jacob Stern
  • There are some other pretty good answers to this question at https://stackoverflow.com/questions/75724281/runtimeerror-errors-in-loading-state-dict/75728573#75728573 – DerekG Aug 04 '23 at 13:14
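A minimal sketch that reproduces the mismatch on CPU without starting a process group. TinyModel and Wrapper are hypothetical stand-ins: TinyModel plays the role of BertCoordinatePredictor, and Wrapper only imitates the one DDP detail that matters here, namely that the wrapped model is stored in an attribute called module.

```python
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for BertCoordinatePredictor."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


class Wrapper(nn.Module):
    """Hypothetical stand-in for DistributedDataParallel: like DDP, it keeps
    the wrapped model in an attribute named `module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


wrapped = Wrapper(TinyModel())
saved = wrapped.state_dict()
print(list(saved))  # ['module.linear.weight', 'module.linear.bias']

# Loading the prefixed keys into a bare model fails under the default strict=True.
fresh = TinyModel()
error_message = None
try:
    fresh.load_state_dict(saved)
except RuntimeError as err:
    error_message = str(err)

# The message lists both the missing (unprefixed) and unexpected (prefixed) keys.
print(error_message)
```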

2 Answers

The names don't match because DDP (DistributedDataParallel) wraps the model object and stores it in an attribute called module. As a result, every key in a state_dict saved from the wrapped model has "module." prepended to it (e.g. "module.bert.embeddings.word_embeddings.weight" instead of "bert.embeddings.word_embeddings.weight"). To resolve this, use

torch.save(model.module.state_dict(), PATH)

instead of

torch.save(model.state_dict(), PATH)

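when saving from a model wrapped in DistributedDataParallel (nn.DataParallel wraps the model the same way). model.module is the original, unwrapped model, so its keys carry no prefix.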
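A sketch of both sides of the fix, again using hypothetical stand-ins (TinyModel for the real model, Wrapper for DDP). Saving wrapped.module.state_dict() produces clean keys; for a checkpoint that was already saved with the prefix, the hypothetical strip_prefix helper removes "module." from each key before a strict load. An in-memory io.BytesIO buffer stands in for PATH.

```python
import io

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model (e.g. BertCoordinatePredictor)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


class Wrapper(nn.Module):
    """Hypothetical stand-in for DDP: keeps the real model in `self.module`."""

    def __init__(self, module):
        super().__init__()
        self.module = module


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop `prefix` from every key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped = Wrapper(TinyModel())

# Save side: the inner model's state_dict has no "module." prefix.
clean = wrapped.module.state_dict()
print(list(clean))  # ['linear.weight', 'linear.bias']

# Load side: simulate an old checkpoint saved from the wrapper itself.
buffer = io.BytesIO()  # stands in for PATH
torch.save(wrapped.state_dict(), buffer)
buffer.seek(0)
legacy = torch.load(buffer, map_location="cpu")

restored = TinyModel()
restored.load_state_dict(strip_prefix(legacy))  # strict load now succeeds
```

Because strip_prefix only touches keys that start with the prefix, an already-clean checkpoint passes through unchanged, so the same loading code can handle both kinds of file.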

Jacob Stern

You can set the strict argument to False in the load_state_dict() function to ignore non-matching keys:

model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
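Keep in mind that strict=False does not rename anything: it only stops load_state_dict() from raising. When every saved key carries the "module." prefix, no key matches, so the model silently keeps its initial weights. The value returned by load_state_dict() lists missing_keys and unexpected_keys, which makes this easy to check. In this sketch TinyModel is a hypothetical stand-in, and the prefixed dictionary simulates a checkpoint saved from a DDP-wrapped model.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


source = TinyModel()
# Simulate a checkpoint saved from a DDP-wrapped model: every key gets "module.".
prefixed = {"module." + key: value for key, value in source.state_dict().items()}

target = TinyModel()
before = target.linear.weight.clone()
result = target.load_state_dict(prefixed, strict=False)  # no exception raised

print(result.missing_keys)     # ['linear.weight', 'linear.bias']
print(result.unexpected_keys)  # ['module.linear.weight', 'module.linear.bias']
# Nothing matched, so the target still has its own initial weights.
print(torch.equal(target.linear.weight, before))  # True
```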
Adrian Mole
Alpcan