First, regarding the model you provided, you can call the model's encoder directly. However, it is not always that simple: you might have a more complicated architecture with several components such as linear layers, CNNs, RNNs, a vector quantizer, etc. In that case, I usually define two additional methods, forward_encoder and forward_decoder, especially when I want to work with the latent space, e.g. for representation learning for clustering, visualization, or data compression.
import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, n_features):
        super(Autoencoder, self).__init__()
        self.n_features = n_features
        self.encoder = nn.Sequential(
            nn.Linear(self.n_features, 1),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(1, self.n_features),
            nn.ReLU(True))

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def forward_encoder(self, x):
        # map the input to the latent code
        latent = self.encoder(x)
        return latent

    def forward_decoder(self, latent):
        # reconstruct the input from the latent code
        x = self.decoder(latent)
        return x
You could also modify the forward function and use a flag that controls whether it returns the latent variable or the reconstructed input.
def forward(self, x, get_latent=False):
    latent = self.encoder(x)
    x = self.decoder(latent)
    if get_latent:
        return latent
    else:
        return x
Note that if you care about the latent space in the middle of the network rather than the quality of the reconstruction, there are better options than a vanilla autoencoder, such as a Denoising Autoencoder (DAE), a Variational Autoencoder (VAE), or a Vector-Quantized Autoencoder (VQ-VAE). Check this answer, where I compared a vanilla autoencoder with a VAE. Also note that other factors play a role in the learned latent space, such as parameter initialization, network architecture/complexity, activation functions, the optimizer, and so forth.
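As a quick illustration of the denoising idea, you can keep the same Autoencoder and simply corrupt the input before encoding while computing the loss against the clean input. This is only a sketch; the Gaussian noise and the noise level sigma are assumptions you should tune for your data:

criterion = nn.MSELoss()
sigma = 0.1  # assumed noise level

def denoising_step(model, x_clean):
    x_noisy = x_clean + sigma * torch.randn_like(x_clean)  # corrupt the input
    x_recon = model(x_noisy)                                # reconstruct from the noisy version
    return criterion(x_recon, x_clean)                      # compare against the clean input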
After completing the training, you need to save the weights for future use (check this page for more details). At inference time you then create a new instance of the model, load the saved weights into it, and run either the entire model or a specific component of it, such as the encoder.
For example, to train the model:
# Create an instance of the model
model = Autoencoder(n_features=10)
# train the model
Train(model, epochs, dataloader, ...)
# save the model parameters
torch.save(model.state_dict(), Save_Path)
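Here Train is just a placeholder for your own training loop. For completeness, a minimal sketch of what it could look like is below; the Adam optimizer, learning rate, MSE loss, and the assumption that the dataloader yields plain tensors are mine, not part of the original setup:

def Train(model, epochs, dataloader):
    # minimal training-loop sketch (optimizer, lr, and loss are assumptions)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    model.train()
    for epoch in range(epochs):
        for x in dataloader:
            optimizer.zero_grad()
            x_recon = model(x)            # forward pass: encode then decode
            loss = criterion(x_recon, x)  # reconstruction loss against the input
            loss.backward()
            optimizer.step()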
At inference time:
model = Autoencoder(n_features=10)
model.load_state_dict(torch.load(Save_Path))
# set the model to evaluation mode
model.eval()

sample_x = torch.zeros(16, 10)
with torch.no_grad():
    out1 = model.encoder(sample_x)
    out2 = model.forward_encoder(sample_x)
    out3 = model(sample_x, get_latent=True)
All of out1, out2, and out3 will contain the same latent representation.
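If your goal is clustering or visualization, you can collect the latent codes for a whole dataset in the same way. This is only a sketch under my assumptions: the dataloader yields plain tensors on the CPU, and scikit-learn's KMeans with an arbitrary number of clusters is used for illustration:

from sklearn.cluster import KMeans

latents = []
model.eval()
with torch.no_grad():
    for x in dataloader:  # assumes the dataloader yields plain tensors
        latents.append(model.forward_encoder(x))
latents = torch.cat(latents).numpy()

# cluster the learned representations (number of clusters is an assumption)
labels = KMeans(n_clusters=3, n_init=10).fit_predict(latents)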