I'm trying to optimize two models in an alternating fashion using PyTorch. The first is a neural network that changes the representation of my data (i.e., a map f(x) on my input data x, parameterized by some weights W). The second is a Gaussian mixture model (GMM) that operates on the f(x) points, i.e., in the neural network space (rather than clustering points in the input space). I am optimizing the GMM using expectation maximization, so the parameter updates are derived analytically rather than by gradient descent.
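For reference, one closed-form EM iteration for a GMM fitted to embedded points could look like the sketch below. It assumes diagonal covariances, which may not match the model in this post, and `em_step`, `eps` and the small count guard are hypothetical; this is not the post's expectation_step / maximization_step.

```python
import math
import torch


def em_step(z, weights, means, variances, eps=1e-6):
    """One EM iteration for a diagonal-covariance GMM fitted to z (hypothetical helper).

    Shapes: z (N, D), weights (K,), means (K, D), variances (K, D).
    Returns updated parameters and the total log-likelihood of z
    under the *incoming* parameters.
    """
    # E-step: log N(z_n | mu_k, diag(var_k)) for every point/component pair -> (N, K)
    diff = z.unsqueeze(1) - means.unsqueeze(0)                       # (N, K, D)
    log_gauss = -0.5 * (diff ** 2 / variances
                        + torch.log(2 * math.pi * variances)).sum(dim=-1)
    log_joint = log_gauss + torch.log(weights)                       # (N, K)
    log_norm = torch.logsumexp(log_joint, dim=1, keepdim=True)       # (N, 1)
    resp = torch.exp(log_joint - log_norm)                           # responsibilities

    # M-step: closed-form maximizers, no gradient descent involved
    nk = resp.sum(dim=0) + 1e-10                                     # effective counts (K,)
    new_weights = nk / nk.sum()
    new_means = (resp.T @ z) / nk.unsqueeze(1)                      # (K, D)
    diff = z.unsqueeze(1) - new_means.unsqueeze(0)                   # (N, K, D)
    new_vars = (resp.unsqueeze(-1) * diff ** 2).sum(dim=0) / nk.unsqueeze(1) + eps
    return new_weights, new_means, new_vars, log_norm.sum()
```

Note that if `z` still carries autograd history (for example, an embedding produced by the network without `.detach()`), every returned parameter carries that history too.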
I have two loss functions here: the first is a function of the distances ||f(x) - f(y)||, and the second is the loss function of the Gaussian mixture model (i.e., how 'clustered' everything looks in the NN representation space). What I want to do is take a step in the NN optimization using both of the above loss functions (since the network's objective depends on both), and then do an expectation-maximization step for the GMM. The code looks like this (I have removed a lot, since there is a ton of code):
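As an illustration of a loss built on the distances ||f(x) - f(y)||, here is a sketch of a standard contrastive loss. NeuralNetworkLoss in this post may well be different; `contrastive_loss`, its `margin` default and the 0/1 encoding of `same` are assumptions.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(fx, fy, same, margin=1.0):
    """Hypothetical pairwise loss on d = ||f(x) - f(y)||.

    fx, fy: (B, D) embeddings of the two halves of each pair.
    same:   (B,) float tensor, 1.0 if the pair shares a label, else 0.0.
    """
    d = torch.linalg.norm(fx - fy, dim=1)           # Euclidean distance per pair
    pull = same * d ** 2                            # similar pairs: shrink the distance
    push = (1 - same) * F.relu(margin - d) ** 2     # dissimilar pairs: push past the margin
    return (pull + push).mean()
```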
```python
import torch

data, labels = load_dataset()
net = NeuralNetwork()
net_optim = torch.optim.Adam(net.parameters(), lr=0.05, weight_decay=1)
# initialize weights, means, and covariances for the Gaussian clusters
concentrations, means, covariances, precisions = initialization(net.forward_one(data))
for i in range(1000):
    net_optim.zero_grad()
    pairs, pair_labels = pairGenerator(data, labels)  # samples some pairs of datapoints
    outputs = net(pairs[:, 0, :], pairs[:, 1, :])  # computes pairwise distances
    net_loss = NeuralNetworkLoss(outputs, pair_labels)  # loss function based on pairwise dist.
    embedding = net.forward_one(data)  # embeds all data in the NN space
    log_prob, log_likelihoods = expectation_step(embedding, means, precisions, concentrations)
    concentrations, means, covariances, precisions = maximization_step(embedding, log_likelihoods)
    gmm_loss = GMMLoss(log_likelihoods, log_prob, precisions, concentrations)
    net_loss.backward(retain_graph=True)
    gmm_loss.backward(retain_graph=True)
    net_optim.step()
```
Essentially, this is what is happening:
1. Sample some pairs of points from the dataset.
2. Push the pairs through the NN and compute the network loss from those outputs.
3. Embed all data points with the NN and perform one clustering EM step in that embedding space.
4. Compute the variational loss (ELBO) from the clustering parameters.
5. Update the neural network parameters using both the variational loss and the network loss.
However, to perform step (5), I have to pass retain_graph=True to backward(); otherwise I get the error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
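A stripped-down sketch that appears to reproduce the same error, under the assumption that it comes from the GMM parameters being computed from the embedding and reused in the next iteration without being detached (in the loop above, means, covariances, precisions and concentrations come out of maximization_step(embedding, ...) and go into the next expectation_step). The first backward succeeds; the second fails because it reaches back into the previous iteration's already-freed graph. All names here are hypothetical toys.

```python
import torch

torch.manual_seed(0)
x = torch.randn(10, 3)
w = torch.randn(3, requires_grad=True)   # stands in for the network weights W
mean = torch.zeros(3)                     # stands in for the initial GMM parameters

failed_at = None
for i in range(3):
    w.grad = None
    z = torch.tanh(x * w)                 # "embedding": part of this iteration's graph
    loss = ((z - mean) ** 2).sum()        # uses the parameters from the previous iteration
    mean = z.mean(dim=0)                  # "M-step" output, still attached to this graph
    try:
        loss.backward()                   # frees this iteration's saved tensors
    except RuntimeError as err:
        failed_at = i
        print(f"iteration {i}: {err}")
        break
```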
It seems like having two loss functions means that I need to retain the computational graph. Is that right?
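Within a single iteration, two losses that share part of a graph can be backpropagated either by summing them and calling backward once, or by retaining the graph only for the first of two backward calls; both give the same accumulated gradients. A toy sketch with made-up tensors, not the model above:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4)
w = torch.randn(4, requires_grad=True)


def two_losses():
    z = torch.tanh(x @ w)                     # shared intermediate, like the embedding f(x)
    loss_a = z.pow(2).mean()                  # toy "network loss"
    loss_b = (z - 1).abs().mean()             # toy "clustering loss"
    return loss_a, loss_b


# Option 1: add the losses and call backward once; retain_graph is not needed
loss_a, loss_b = two_losses()
(loss_a + loss_b).backward()
g_sum = w.grad.clone()

# Option 2: two backward calls; only the first must keep the shared graph alive
w.grad = None
loss_a, loss_b = two_losses()
loss_a.backward(retain_graph=True)
loss_b.backward()                             # gradients accumulate into w.grad
g_two = w.grad.clone()
```

Summing is arguably the simpler of the two, since it needs only one backward pass.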
I am not sure how to work around this, because with retain_graph=True, by around iteration 400 each iteration takes ~30 minutes to complete. Does anyone know how I might fix this? Apologies in advance; I am still very new to automatic differentiation.
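A toy sketch of why a slowdown like this could be expected, under the assumption that the GMM parameters are never detached: each new mean is computed from the previous one (much as the M-step here uses log-likelihoods computed with the previous means, precisions and concentrations), so the graph behind every new loss contains all earlier iterations, and with retain_graph=True each backward call traverses more nodes than the last. `graph_size` and `run` are hypothetical helpers; in this toy, detaching the updated parameter keeps the graph the same size every iteration.

```python
import torch


def graph_size(t):
    """Count the distinct autograd nodes reachable from t.grad_fn."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)                          # keep node objects alive so identity is stable
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)


def run(detach_params, n_iters=5):
    """Toy alternating loop; returns the graph size behind each iteration's loss."""
    torch.manual_seed(0)
    x = torch.randn(10, 3)
    w = torch.randn(3, requires_grad=True)    # stands in for the network weights W
    mean = torch.zeros(3)                     # stands in for the GMM parameters
    sizes = []
    for _ in range(n_iters):
        w.grad = None
        z = torch.tanh(x * w)                 # "embedding"
        loss = ((z - mean) ** 2).sum()
        # EM-like update: the new mean depends on z and on the previous mean
        resp = torch.softmax(-((z - mean) ** 2), dim=0)
        mean = (resp * z).sum(dim=0)
        if detach_params:
            mean = mean.detach()              # drop the link to earlier iterations
        loss.backward(retain_graph=True)
        sizes.append(graph_size(loss))
    return sizes


grow = run(detach_params=False)
flat = run(detach_params=True)
print(grow)
print(flat)
```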