I have two networks that I need to concatenate for my full model. However, my first model is pre-trained, and I need to make it non-trainable when training the full model. How can I achieve this in PyTorch?
I am able to concatenate two models using this answer:
import torch
import torch.nn as nn


class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x


class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        # in_features must match modelA's output size (2), since the
        # ensemble feeds modelA's output into modelB
        self.fc1 = nn.Linear(2, 2)

    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

    def forward(self, x):
        x1 = self.modelA(x)
        x2 = self.modelB(x1)
        return x2


# Create models and load the pre-trained state_dict for modelA
modelA = MyModelA()
modelB = MyModelB()
modelA.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x = torch.randn(1, 10)
output = model(x)
Basically, I want to load the pre-trained modelA and make it non-trainable when training the ensemble model.
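One common approach (my own sketch, not something from the linked answer) is to freeze modelA by setting requires_grad = False on all of its parameters and to hand the optimizer only the parameters that should still be trained. A minimal sketch, assuming the classes above and an SGD optimizer with a placeholder learning rate:

# Freeze every parameter of the pre-trained sub-network so autograd
# neither computes nor updates gradients for it.
for param in modelA.parameters():
    param.requires_grad = False
modelA.eval()  # also fixes dropout/batch-norm behaviour, if any

model = MyEnsemble(modelA, modelB)

# Give the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.01
)

Note that calling model.train() on the ensemble also switches modelA back to training mode, so if modelA contains dropout or batch-norm layers you may want to call modelA.eval() again inside the training loop.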