
I have the following PyTorch model:

import math
from abc import abstractmethod

import torch.nn as nn


class AlexNet3D(nn.Module):
    @abstractmethod
    def get_head(self):
        pass

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.features = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=0),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),

            nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=0),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),

            nn.Conv3d(128, 192, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),

            nn.Conv3d(192, 192, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(192),
            nn.ReLU(inplace=True),

            nn.Conv3d(192, 128, kernel_size=(3, 3, 3), padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=3, stride=3),
        )

        self.classifier = self.get_head()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        xp = self.features(x)
        x = xp.view(xp.size(0), -1)
        x = self.classifier(x)
        return [x, xp]


class AlexNet3DDropoutRegression(AlexNet3D):
    def get_head(self):
        return nn.Sequential(nn.Dropout(),
                             nn.Linear(self.input_size, 64),
                             nn.ReLU(inplace=True),
                             nn.Dropout(),
                             nn.Linear(64, 1),
                             )

I am initializing the model like this:

def init_model(self):
    model = AlexNet3DDropoutRegression(4608)
    if self.use_cuda:
        log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(self.device)
    return model

After training, I save the model like this:

    torch.save(self.model.state_dict(), self.cli_args.model_save_location)

I then attempt to load the saved model:

import torch
from reprex.models import AlexNet3DDropoutRegression


model_save_location = "/home/feczk001/shared/data/AlexNet/LoesScoring/loes_scoring_01.pt"

model = AlexNet3DDropoutRegression(4608)
model.load_state_dict(torch.load(model_save_location,
                                 map_location='cpu'))

But I get the following error:

RuntimeError: Error(s) in loading state_dict for AlexNet3DDropoutRegression:
    Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.weight", "features.4.bias", "features.5.weight", "features.5.bias", "features.5.running_mean", "features.5.running_var", "features.8.weight", "features.8.bias", "features.9.weight", "features.9.bias", "features.9.running_mean", "features.9.running_var", "features.11.weight", "features.11.bias", "features.12.weight", "features.12.bias", "features.12.running_mean", "features.12.running_var", "features.14.weight", "features.14.bias", "features.15.weight", "features.15.bias", "features.15.running_mean", "features.15.running_var", "classifier.1.weight", "classifier.1.bias", "classifier.4.weight", "classifier.4.bias". 
    Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.1.weight", "module.features.1.bias", "module.features.1.running_mean", "module.features.1.running_var", "module.features.1.num_batches_tracked", "module.features.4.weight", "module.features.4.bias", "module.features.5.weight", "module.features.5.bias", "module.features.5.running_mean", "module.features.5.running_var", "module.features.5.num_batches_tracked", "module.features.8.weight", "module.features.8.bias", "module.features.9.weight", "module.features.9.bias", "module.features.9.running_mean", "module.features.9.running_var", "module.features.9.num_batches_tracked", "module.features.11.weight", "module.features.11.bias", "module.features.12.weight", "module.features.12.bias", "module.features.12.running_mean", "module.features.12.running_var", "module.features.12.num_batches_tracked", "module.features.14.weight", "module.features.14.bias", "module.features.15.weight", "module.features.15.bias", "module.features.15.running_mean", "module.features.15.running_var", "module.features.15.num_batches_tracked", "module.classifier.1.weight", "module.classifier.1.bias", "module.classifier.4.weight", "module.classifier.4.bias". 

What is going wrong here?

Paul Reiners

3 Answers


The issue is that you trained the model with DataParallel and are now attempting to reload the weights into a non-parallel network. DataParallel is a wrapper class that makes the original model (a torch.nn.Module) a class attribute of the DataParallel object, named module. This issue has been addressed on PyTorch Discuss, Stack Overflow, and GitHub, so I won't rehash the details here, but you can fix it in any of the following ways:

  1. Save and load the model exclusively as a DataParallel object (a minimal sketch follows), which stops working as soon as you want to use the model for inference outside DataParallel, or
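
For reference, a minimal sketch of option 1 (assuming the imports from the question, torch and torch.nn as nn, plus a CUDA device; the weights only round-trip if the loading model is wrapped in DataParallel too):

    model = nn.DataParallel(AlexNet3DDropoutRegression(4608)).to("cuda")
    # the saved keys carry the "module." prefix added by the wrapper
    torch.save(model.state_dict(), "parallel_save.pt")

    # reloading works only into another DataParallel-wrapped model;
    # a plain model raises exactly the key mismatch from the question
    model = nn.DataParallel(AlexNet3DDropoutRegression(4608))
    model.load_state_dict(torch.load("parallel_save.pt"))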

  2. Save the inner module's state_dict instead, like so:

    # save the state dict of the wrapped model, not the DataParallel wrapper
    torch.save(model.module.state_dict(), path)

    # ... later
    # reload the weights on a non-parallel model
    model.load_state_dict(torch.load(path))
    

Here's a trivial example:

model = AlexNet3DDropoutRegression(4608) # on cpu
model = nn.DataParallel(model)
model = model.to("cuda") # DataParallel object on GPU(s)


torch.save(model.module.state_dict(),"example_path.pt")

del model
model = AlexNet3DDropoutRegression(4608)

ret = model.load_state_dict(torch.load("example_path.pt")) 
print(ret) 

Output:

>>> <All keys matched successfully>
  3. Alternatively, and perhaps most usefully if you already have a saved state_dict you need to reload, load the DataParallel state_dict, remap the key names to strip the "module." prefix, and then load the re-keyed state_dict. Something like:
incompatible_state_dict = torch.load("DataParallel_save_file.pt")
state_dict = {}
for key in incompatible_state_dict:
    # strip the leading "module." added by DataParallel
    state_dict[key.split("module.")[-1]] = incompatible_state_dict[key]

ret = model.load_state_dict(state_dict)
print(ret)

Output:

>>> <All keys matched successfully>
DerekG
  • I'm getting a warning of "unexpected argument" for the `map_location` argument of `torch.save()`. – Paul Reiners Mar 15 '23 at 13:25
  • Try removing that argument and see if it gives you an error. I believe you shouldn't need it. – DerekG Mar 15 '23 at 13:39
  • I tried removing that argument and I was back in the same situation I was in before. – Paul Reiners Mar 15 '23 at 18:21
  • 2
    I was able to save and load the model weights using your above code and the additional lines listed in this answer. The critical bit is that if your model is wrapped in a `DataParallel` object, you need to use `model.module.state_dict()` to access the parameters, and if not you simply do `model.state_dict()`. So depending on whether you load and save the checkpoint as a single `nn.module` or as a `DataParallel` object, you'll need to use the appropriate attribute to access the model weights. – DerekG Mar 15 '23 at 19:08

nn.DataParallel is a wrapper class that adds a "module." prefix to all the keys in the state dictionary, which is why you see module.features and module.classifier among the unexpected keys. To solve this, all you need to do is remove the module. prefix when loading the state_dict:

model = AlexNet3DDropoutRegression(4608)
model_save_location = "/home/feczk001/shared/data/AlexNet/LoesScoring/loes_scoring_01.pt"

state_dict = torch.load(model_save_location, map_location='cpu')
model.load_state_dict({k.replace("module.", ""): v for k, v in state_dict.items()})
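
Note that `replace("module.", "")` strips the substring anywhere it occurs in a key; on Python 3.9+, `str.removeprefix` only strips a leading match, which is slightly safer:

model.load_state_dict({k.removeprefix("module."): v for k, v in state_dict.items()})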
Hamzah

Your issue is that you are loading a state dictionary saved from a DataParallel model into a new model that does not use DataParallel. DataParallel prefixes every key in the state_dict with module., which is where the unexpected module. keys come from. Unless you want to wrap the new model in DataParallel as well, you are better off just removing that prefix.

This snippet should do it:

model = AlexNet3DDropoutRegression(4608)
state_dict = torch.load(model_save_location, map_location='cpu')

# rebuild the state dict without the "module." prefix added by DataParallel
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace("module.", "")
    new_state_dict[new_key] = state_dict[key]
model.load_state_dict(new_state_dict)
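
If you'd rather not rebuild the dict by hand, recent PyTorch versions also ship a helper that strips a prefix in place (consume_prefix_in_state_dict_if_present; check that your installed version provides it):

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

state_dict = torch.load(model_save_location, map_location='cpu')
# removes the leading "module." from every key, in place
consume_prefix_in_state_dict_if_present(state_dict, "module.")
model.load_state_dict(state_dict)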
Tendekai Muchenje