
I am trying to get better results by making the final few layers of a previously frozen backbone (RegNet-800MF) trainable. How can I implement this in PyTorch Lightning? I am very new to ML, so please excuse me if I have left out any important information.

My model (MechClassifier) wraps another class (ParametersClassifier), which contains the pre-trained RegNet as its frozen backbone. During training, the forward function passes inputs only through the backbone of ParametersClassifier, not through its classification layers. The __init__ and forward functions of both are included below.

My MechClassifier model:

class MechClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        lr=4e-3,
        weight_decay=1e-8,
        gpus=1,
        max_epochs=30,
    ):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.__dict__.update(locals())
        
        self.backbone = ParametersClassifier.load_from_checkpoint(
            checkpoint_path="checkpoints/param_classifier/last.ckpt",
            num_classes=3,
            gpus=1,
        )
        
        self.backbone.freeze()
        self.backbone.eval()


        self.mf_classifier = nn.Sequential(
            nn.Linear(self.backbone.num_ftrs, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )
        
        self.wd_classifier = nn.Sequential(
            nn.Linear(self.backbone.num_ftrs, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )

    def forward(self, x):
        self.backbone.eval()
        with torch.no_grad():
            x = self.backbone.model(x)

        # x = self.model(x)

        out1 = self.mf_classifier(x)
        out2 = self.wd_classifier(x)

        # print(out1.size())
        return (out1, out2)

ParametersClassifier (loaded from checkpoint):

class ParametersClassifier(pl.LightningModule):
    def __init__(
        self,
        num_classes,
        lr=4e-3,
        weight_decay=0.05,
        gpus=1,
        max_epochs=30,
    ):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.__dict__.update(locals())

        self.model = models.regnet_y_800mf(pretrained=True)
        self.num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()
        self.fc1 = nn.Linear(self.num_ftrs, num_classes)
        self.fc2 = nn.Linear(self.num_ftrs, num_classes)
        self.fc3 = nn.Linear(self.num_ftrs, num_classes)
        self.fc4 = nn.Linear(self.num_ftrs, num_classes)

    def forward(self, x):
        x = self.model(x)
        out1 = self.fc1(x)
        out2 = self.fc2(x)
        out3 = self.fc3(x)
        out4 = self.fc4(x)
        return (out1, out2, out3, out4)
Ivan
tom_walkr
  • So you are looking to only train `mf_classifier` and `wd_classifier`, correct? – Ivan Jun 22 '22 at 14:06
  • @Ivan Previously, yes, I have only been training mf_classifier and wd_classifier. However, now I would like to train mf_classifier and wd_classifier, plus the final few layers of the backbone (the RegNet within the ParametersClassifier class). – tom_walkr Jun 22 '22 at 16:11
  • Ok, can you show the `forward` function of `ParametersClassifier`? – Ivan Jun 22 '22 at 19:53
  • @Ivan - I have edited the question to show both forward steps. – tom_walkr Jun 23 '22 at 11:26

1 Answer


You can look at the torchvision implementation of the RegNet model you are using. Its forward function:

def forward(self, x: Tensor) -> Tensor:
    x = self.stem(x)
    x = self.trunk_output(x)

    x = self.avgpool(x)
    x = x.flatten(start_dim=1)
    x = self.fc(x)

    return x

Instead of wrapping the backbone call in a torch.no_grad context manager as you did, toggle requires_grad on the parameters you want frozen or trainable. Under torch.no_grad no autograd graph is recorded, so no backbone parameter can receive a gradient, whatever its flag says. By default, module parameters have requires_grad set to True, which means gradients are computed for them during the backward pass. Parameters whose flag is set to False are effectively frozen.
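To make the difference concrete, here is a minimal sketch using toy stand-ins rather than the real RegNet. The names backbone and head, and the layer sizes, are assumptions made purely for illustration. It compares which parameters receive a gradient in each approach.

```python
import torch
from torch import nn

# Toy stand-ins (hypothetical): a two-layer "backbone" and a small head.
backbone = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
head = nn.Linear(4, 2)
x = torch.randn(3, 4)

# Case 1: torch.no_grad() around the backbone. No graph is recorded for it,
# so its parameters get no gradient even though requires_grad is still True.
with torch.no_grad():
    feats = backbone(x)
head(feats).sum().backward()
no_grad_result = [p.grad is None for p in backbone.parameters()]

# Case 2: freeze the whole backbone, then unfreeze only its last layer.
head.zero_grad()
backbone.requires_grad_(False)
backbone[2].requires_grad_(True)
head(backbone(x)).sum().backward()
first_layer_has_grad = backbone[0].weight.grad is not None
last_layer_has_grad = backbone[2].weight.grad is not None

print(no_grad_result, first_layer_has_grad, last_layer_has_grad)
```

In the first case every backbone parameter ends up without a gradient. In the second, only the unfrozen last layer gets one.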

You can choose manually which layers to freeze and which to fine-tune. For example, to freeze the RegNet backbone except for the last section of its fourth block, replace the following in MechClassifier's __init__:

self.backbone.freeze()
self.backbone.eval()

With the following lines:

## freeze all
self.backbone.model.requires_grad_(False)

## unfreeze last section of 4th block of backbone 
block4_section1 = getattr(self.backbone.model.trunk_output.block4, 'block4-1')
block4_section1.requires_grad_(True)
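A quick way to check that the freeze/unfreeze pattern above did what you expect is to list the parameters that are still trainable. The sketch below uses a hypothetical stand-in that only mimics the submodule naming used above (trunk_output, block4, 'block4-1'). It is not the real RegNet, and the Linear layers are placeholders.

```python
from collections import OrderedDict
from torch import nn

# Hypothetical stand-in: mimics only the naming scheme, not the real model.
def make_stage(name, depth):
    return nn.Sequential(OrderedDict(
        (f"{name}-{i}", nn.Linear(8, 8)) for i in range(depth)
    ))

model = nn.Module()
model.trunk_output = nn.Sequential(OrderedDict(
    [("block3", make_stage("block3", 2)), ("block4", make_stage("block4", 2))]
))

# Same pattern as above: freeze everything, then unfreeze one section.
# getattr is needed because 'block4-1' is not a valid Python attribute name.
model.requires_grad_(False)
getattr(model.trunk_output.block4, "block4-1").requires_grad_(True)

# List the names of the parameters that will still be updated.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)
```

Running the same list comprehension on self.backbone.model should show only the parameters of the section you unfroze.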

Then define MechClassifier's forward function without the torch.no_grad context, like so:

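def forward(self, x):
    # Keep the backbone in eval mode so its BatchNorm running statistics stay fixed
    self.backbone.eval()
    # No torch.no_grad() here, so gradients can reach the unfrozen block
    x = self.backbone.model(x)
    out1 = self.mf_classifier(x)
    out2 = self.wd_classifier(x)
    return (out1, out2)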
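The unfrozen backbone parameters also have to be handed to the optimizer, otherwise they will never be updated. The question does not show configure_optimizers, so what follows is only a sketch under assumptions: build_optimizer is a hypothetical helper, AdamW is an assumed optimizer choice, and the smaller backbone_lr is an illustrative value, not a recommendation taken from the question.

```python
import torch
from torch import nn

def build_optimizer(module, backbone, lr=4e-3, backbone_lr=4e-4, weight_decay=1e-8):
    """Hypothetical helper: one param group for the heads, one for the backbone."""
    backbone_ids = {id(p) for p in backbone.parameters()}
    # Only parameters with requires_grad=True are passed to the optimizer.
    backbone_params = [p for p in backbone.parameters() if p.requires_grad]
    head_params = [
        p for p in module.parameters()
        if p.requires_grad and id(p) not in backbone_ids
    ]
    return torch.optim.AdamW(
        [
            {"params": head_params, "lr": lr},
            {"params": backbone_params, "lr": backbone_lr},
        ],
        weight_decay=weight_decay,
    )

# Toy module (hypothetical) with a partly unfrozen backbone and a head.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
        self.head = nn.Linear(4, 2)

toy = Toy()
toy.backbone.requires_grad_(False)
toy.backbone[1].requires_grad_(True)
opt = build_optimizer(toy, toy.backbone)
group_sizes = [len(g["params"]) for g in opt.param_groups]
print(group_sizes)
```

In MechClassifier, something like this could be returned from configure_optimizers, for example build_optimizer(self, self.backbone). A lower learning rate for the pre-trained layers is a common fine-tuning choice, but whether it helps here is something to check empirically.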
Ivan