I am trying to get better results by allowing a few final layers of a previously frozen backbone (RegNet-800MF) to be trained. How can I implement this in PyTorch Lightning? I am very new to ML so please excuse me if I have left any important information out.
My model (MechClassifier) wraps another LightningModule (ParametersClassifier), which contains the pre-trained RegNet as its frozen backbone. During training, MechClassifier's forward function passes inputs only through the RegNet backbone of the ParametersClassifier and not through its classification layers. I have included the `__init__` and `forward` functions of both classes below.
My MechClassifier model:
class MechClassifier(pl.LightningModule):
    """Two-headed classifier on top of a pre-trained ``ParametersClassifier``.

    Only the feature extractor of the loaded checkpoint (``backbone.model``)
    is used; the checkpoint's own linear heads are never called here. Two
    small MLP heads, ``mf_classifier`` and ``wd_classifier``, map the backbone
    features to ``num_classes`` logits each.

    Args:
        num_classes: Number of output logits produced by each head.
        lr: Learning rate; stored on the module for use elsewhere.
        weight_decay: Weight decay; stored on the module for use elsewhere.
        gpus: Stored on the module; not used inside this class body.
        max_epochs: Stored on the module; not used inside this class body.
        unfreeze_last_n_blocks: How many of the final backbone stages should
            be trainable. ``0`` (the default) keeps the whole backbone frozen
            and runs it without gradient tracking.

    Raises:
        ValueError: If ``unfreeze_last_n_blocks`` is negative.
    """

    def __init__(
        self,
        num_classes,
        lr=4e-3,
        weight_decay=1e-8,
        gpus=1,
        max_epochs=30,
        unfreeze_last_n_blocks=0,
    ):
        super().__init__()
        if unfreeze_last_n_blocks < 0:
            raise ValueError("unfreeze_last_n_blocks must be >= 0")

        # Explicit assignments instead of ``self.__dict__.update(locals())``:
        # that call also stored ``self`` and ``__class__`` on the instance.
        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.gpus = gpus
        self.max_epochs = max_epochs
        self.unfreeze_last_n_blocks = unfreeze_last_n_blocks

        self.backbone = ParametersClassifier.load_from_checkpoint(
            checkpoint_path="checkpoints/param_classifier/last.ckpt",
            num_classes=3,
            gpus=1,
        )
        self.backbone.freeze()
        self.backbone.eval()

        if unfreeze_last_n_blocks > 0:
            # NOTE(review): assumes the torchvision RegNet exposes its stages
            # as the sequential ``trunk_output`` -- confirm for the installed
            # torchvision version.
            stages = list(self.backbone.model.trunk_output.children())
            for stage in stages[-unfreeze_last_n_blocks:]:
                for param in stage.parameters():
                    param.requires_grad = True
            # NOTE(review): the optimizer must be given these parameters
            # (e.g. built from ``self.parameters()``) for them to be updated;
            # optimizer setup is not visible in this block.

        self.mf_classifier = nn.Sequential(
            nn.Linear(self.backbone.num_ftrs, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )
        self.wd_classifier = nn.Sequential(
            nn.Linear(self.backbone.num_ftrs, 8),
            nn.ReLU(),
            nn.Linear(8, num_classes),
        )

    def forward(self, x):
        """Return the ``(mf_classifier, wd_classifier)`` outputs for ``x``."""
        # Re-assert eval mode on every call so that a ``.train()`` on the
        # parent module does not leave the backbone in training mode.
        self.backbone.eval()

        # Track gradients through the backbone only when part of it was
        # unfrozen, and never re-enable them inside an outer no-grad context.
        track_grad = self.unfreeze_last_n_blocks > 0 and torch.is_grad_enabled()
        with torch.set_grad_enabled(track_grad):
            features = self.backbone.model(x)

        out1 = self.mf_classifier(features)
        out2 = self.wd_classifier(features)
        return (out1, out2)
ParametersClassifier (loaded from checkpoint):
class ParametersClassifier(pl.LightningModule):
    """RegNet feature extractor followed by four parallel linear heads.

    The network returned by ``models.regnet_y_800mf`` has its final ``fc``
    layer replaced by ``nn.Identity`` so that ``self.model`` outputs
    ``num_ftrs`` features; ``fc1`` .. ``fc4`` each map those features to
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits produced by each head.
        lr: Learning rate; stored on the module for use elsewhere.
        weight_decay: Weight decay; stored on the module for use elsewhere.
        gpus: Stored on the module; not used inside this class body.
        max_epochs: Stored on the module; not used inside this class body.
    """

    def __init__(
        self,
        num_classes,
        lr=4e-3,
        weight_decay=0.05,
        gpus=1,
        max_epochs=30,
    ):
        super().__init__()
        # Explicit assignments instead of ``self.__dict__.update(locals())``:
        # that call also stored ``self`` and ``__class__`` on the instance.
        self.num_classes = num_classes
        self.lr = lr
        self.weight_decay = weight_decay
        self.gpus = gpus
        self.max_epochs = max_epochs

        # NOTE(review): newer torchvision releases appear to prefer a
        # ``weights=`` argument over ``pretrained=True`` -- confirm before
        # upgrading torchvision.
        self.model = models.regnet_y_800mf(pretrained=True)
        # Remember the feature width before dropping the original classifier.
        self.num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Identity()

        # Attribute names and creation order are kept unchanged so existing
        # checkpoints continue to load.
        self.fc1 = nn.Linear(self.num_ftrs, num_classes)
        self.fc2 = nn.Linear(self.num_ftrs, num_classes)
        self.fc3 = nn.Linear(self.num_ftrs, num_classes)
        self.fc4 = nn.Linear(self.num_ftrs, num_classes)

    def forward(self, x):
        """Return the outputs of ``fc1`` .. ``fc4`` as a 4-tuple."""
        features = self.model(x)
        return (
            self.fc1(features),
            self.fc2(features),
            self.fc3(features),
            self.fc4(features),
        )