nn.DataParallel wraps the model and stores the actual model in its module attribute. As a consequence, every key in the wrapper's state dict carries a module. prefix.
Let's look at a very simplified version with just one convolution to see the difference:
import torch
import torch.nn as nn

class NestedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

model = NestedUNet()
model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])

# Wrap the model in DataParallel
num_gpus = torch.cuda.device_count()  # assumed here: use every visible GPU
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))
model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])
The state dict you saved with nn.DataParallel does not line up with the regular model's state dict. Because you merge the current state dict with the loaded one, the loaded state is silently ignored: the model has no parameters that match the module.-prefixed keys, so you are left with the randomly initialised model.
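The merging code itself is not shown here, so the following is only a sketch of one common merge pattern, assumed for illustration. Plain dicts with placeholder strings stand in for the real state dicts, and the names model_dict, matching and merged are hypothetical.

```python
# Plain dicts stand in for state dicts; the values are placeholders, not tensors.
model_dict = {"conv1.weight": "random init", "conv1.bias": "random init"}
pretrained_dict = {"module.conv1.weight": "trained", "module.conv1.bias": "trained"}

# Assumed merge pattern: keep only the loaded entries whose keys the model knows.
matching = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# Merge the surviving entries into a copy of the model's current state.
merged = dict(model_dict)
merged.update(matching)

print(matching)              # no key survives the filter
print(merged == model_dict)  # the "merged" state is just the random initialisation
```

Since none of the prefixed keys pass the filter, the merge changes nothing and raises no error, which is why the problem goes unnoticed.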
To avoid that mistake, don't merge the state dicts. Apply the loaded state dict directly to the model instead, in which case you get an error if the keys don't match:
RuntimeError: Error(s) in loading state_dict for NestedUNet:
Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".
To make the state dict that you have saved compatible, you can strip off the module. prefix:
pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)
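One caveat: str.replace removes every occurrence of module., not only the leading one, so a key such as submodule.weight would be mangled as well. If your model has attribute names ending in "module", a stricter variant that only touches the start of the key may be safer. The helper name strip_prefix below is hypothetical, and plain strings stand in for the tensors.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values; a real state dict would map these keys to tensors.
saved = {
    "module.conv1.weight": "w",
    "module.submodule.weight": "s",
}

stripped = strip_prefix(saved)
print(list(stripped))  # ['conv1.weight', 'submodule.weight']

# For comparison, replace() also rewrites the inner "module." occurrence.
print("module.submodule.weight".replace("module.", ""))  # 'subweight'
```

Keys without the prefix are left unchanged, so the helper is also harmless on a state dict that was saved from an unwrapped model.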
You can also avoid this issue in the future by unwrapping the model from nn.DataParallel before saving its state, i.e. by saving model.module.state_dict(). That way you can always load the model with its state first and decide later whether to put it into nn.DataParallel to use multiple GPUs.
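A minimal sketch of that save and load round trip, reusing the simplified NestedUNet from above. The in-memory buffer is an assumption made to keep the example self-contained; in practice you would pass a checkpoint file path to torch.save and torch.load.

```python
import io

import torch
import torch.nn as nn


class NestedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)


# Training side: the model is wrapped, but the saved state comes from .module.
model_dp = nn.DataParallel(NestedUNet())
buffer = io.BytesIO()  # stands in for a checkpoint file path
torch.save(model_dp.module.state_dict(), buffer)

# Loading side: restore into a plain, unwrapped model first.
buffer.seek(0)
state = torch.load(buffer, map_location="cpu")
restored = NestedUNet()
restored.load_state_dict(state)  # strict by default, so mismatched keys would raise

# Wrapping is now an optional, separate step.
restored_dp = nn.DataParallel(restored)
```

The saved keys have no module. prefix, so the same checkpoint loads on a single GPU, on CPU, or into a model that is wrapped afterwards.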