
I train on the GPU, with num_gpus set to 1:

device_ids = list(range(num_gpus))
model = NestedUNet(opt.num_channel, 2).to(device)
model = nn.DataParallel(model, device_ids=device_ids)

I test on the CPU:

model = NestedUNet_Purn2(opt.num_channel, 2).to(dev)
device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)
model_old = torch.load(path, map_location=dev)
pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

This gives the correct result, but when I delete:

device_ids = list(range(num_gpus))
model = torch.nn.DataParallel(model, device_ids=device_ids)

the result is wrong.

likyoo

2 Answers


nn.DataParallel wraps the model and assigns the actual model to its module attribute. That also means that the keys in the state dict have a module. prefix.

Let's look at a very simplified version with just one convolution to see the difference:

class NestedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

model = NestedUNet()

model.state_dict().keys() # => odict_keys(['conv1.weight', 'conv1.bias'])

# Wrap the model in DataParallel
model_dp = nn.DataParallel(model, device_ids=range(num_gpus))

model_dp.state_dict().keys() # => odict_keys(['module.conv1.weight', 'module.conv1.bias'])

The state dict you saved with nn.DataParallel does not line up with the regular model's state. You are merging the current state dict with the loaded state dict, which means that the loaded state is ignored, because the model does not have any attributes that match those keys, and you are left with the randomly initialised model.

To avoid making that mistake, you shouldn't merge the state dicts, but rather apply the loaded state dict to the model directly, in which case you get an error if the keys don't match.
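
For example, sticking with the simplified model from above:

# strict key matching is the default, so the prefix mismatch fails loudly
model.load_state_dict(model_dp.state_dict())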

RuntimeError: Error(s) in loading state_dict for NestedUNet:
        Missing key(s) in state_dict: "conv1.weight", "conv1.bias".
        Unexpected key(s) in state_dict: "module.conv1.weight", "module.conv1.bias".

To make the state dict that you have saved compatible, you can strip off the module. prefix:

pretrained_dict = {key.replace("module.", ""): value for key, value in pretrained_dict.items()}
model.load_state_dict(pretrained_dict)

You can also avoid this issue in the future by unwrapping the model from nn.DataParallel before saving its state, i.e. saving model.module.state_dict(). That way you can always load the model with its state first, and later decide to put it into nn.DataParallel if you want to use multiple GPUs.
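
A minimal sketch of that pattern, using the simplified NestedUNet from above and a placeholder checkpoint.pth path:

# save only the underlying model's parameters, without the module. prefix
torch.save(model_dp.module.state_dict(), "checkpoint.pth")

# later: load the state into a plain model first
model = NestedUNet()
model.load_state_dict(torch.load("checkpoint.pth", map_location="cpu"))

# and only wrap it in nn.DataParallel if multiple GPUs should be used
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)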

Michael Jungo

You trained your model using DataParallel and saved it, so the model weights were stored with a module. prefix. Now, when you load it without DataParallel, you are basically not loading any model weights (the model keeps its random weights). As a result, the model predictions are wrong.

Here is an example.

model = nn.Linear(2, 4)
model = torch.nn.DataParallel(model, device_ids=device_ids)
model.state_dict().keys() # => odict_keys(['module.weight', 'module.bias'])

On the other hand,

another_model = nn.Linear(2, 4)
another_model.state_dict().keys() # => odict_keys(['weight', 'bias'])

See the difference in the OrderedDict keys.

So, in your code, the following three lines run without error, but no model weights are actually loaded.

pretrained_dict = model_old.state_dict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

Here, model_dict has keys without the module. prefix (since you are not using DataParallel), but pretrained_dict does have it. So, essentially, pretrained_dict ends up empty when DataParallel is not used.
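
You can verify that with the two toy models above (a quick sketch reusing model, which is wrapped, and another_model, which is not):

pretrained_dict = model.state_dict()        # keys: module.weight, module.bias
model_dict = another_model.state_dict()     # keys: weight, bias
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
print(pretrained_dict)  # => {} -- no keys overlap, so nothing gets loaded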


Solution: If you want to avoid using DataParallel at test time, you can load the weights file, create a new OrderedDict without the module. prefix, and load that back.

Something like the following would work for your case without using DataParallel.

# original saved file with DataParallel
model_old = torch.load(path, map_location=dev)

# create new OrderedDict that does not contain `module.`
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in model_old.state_dict().items():  # model_old is the full saved model here, not a bare state dict
    name = k[7:]  # strip the `module.` prefix (7 characters)
    new_state_dict[name] = v

# load params
model.load_state_dict(new_state_dict)

Wasi Ahmad