I'm currently working on a variant of the Variational Autoencoder (VAE) in a sequential setting, where the task is to fit/recover a sequence of real-valued observations (hence it is a regression problem).
I have built my model using tf.keras with eager execution enabled, together with tensorflow_probability (tfp). Following the VAE concept, the generative net emits the parameters of the distribution over the observations, which I model as a multivariate normal. The outputs are therefore the mean and logvar of the predicted distribution.
Regarding the training process, the first component of the loss is the reconstruction error, i.e. the log likelihood of the true observation under the distribution whose parameters are predicted by the generative net. Here, I use tfp.distributions, since it is fast and handy.
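To make the reconstruction term concrete, here is a minimal framework-free sketch of the log likelihood of one observation under a diagonal Gaussian parameterised by mean and logvar. It is written in plain Python rather than with tfp.distributions so the arithmetic is visible; the function name `diag_gaussian_log_prob` and the example values are hypothetical and not part of my model.

```python
import math

def diag_gaussian_log_prob(x, mean, logvar):
    """Log density of x under N(mean, diag(exp(logvar))), computed in log-space."""
    total = 0.0
    for x_d, m_d, lv_d in zip(x, mean, logvar):
        # Per-dimension term: -0.5 * (log(2*pi) + logvar + (x - mean)^2 / var)
        total += -0.5 * (math.log(2.0 * math.pi) + lv_d
                         + (x_d - m_d) ** 2 / math.exp(lv_d))
    return total

# Hypothetical 2-D observation with a predicted mean and log-variance.
target = [0.5, -1.0]
obs_mean = [0.0, 0.0]
obs_logvar = [0.0, 0.0]  # logvar = 0 means unit variance in both dimensions

# The reconstruction loss is the negative log likelihood of the target.
reconstruction_loss = -diag_gaussian_log_prob(target, obs_mean, obs_logvar)
```

Working with logvar directly means the density itself is never formed, only its logarithm.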
However, after training is done, marked by a considerably low loss value, it turns out that my model does not seem to have learned anything. The predicted values from the model are almost flat across the time dimension (recall that the problem is sequential).
Nevertheless, as a sanity check, when I replace the log likelihood with an MSE loss (which is not justifiable when working on a VAE), it yields a very good fit to the data. So I conclude that there must be something wrong with this log likelihood term. Does anyone have a clue and/or a solution for this?
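For comparison between the two losses, the sketch below evaluates the Gaussian negative log likelihood with the variance held fixed at 1 next to the summed squared error for the same prediction. In that fixed-variance case the two differ only by a factor of 0.5 and an additive constant. The helper names and numbers are hypothetical, and the sketch says nothing about the learned-variance case used in my model.

```python
import math

def gaussian_nll_unit_variance(x, mean):
    """Negative log density of x under N(mean, I)."""
    dim = len(x)
    squared_error = sum((x_d - m_d) ** 2 for x_d, m_d in zip(x, mean))
    return 0.5 * squared_error + 0.5 * dim * math.log(2.0 * math.pi)

def sum_squared_error(x, mean):
    """Plain summed squared error between target and prediction."""
    return sum((x_d - m_d) ** 2 for x_d, m_d in zip(x, mean))

target = [0.5, -1.0]      # hypothetical observation
prediction = [0.2, -0.4]  # hypothetical predicted mean

nll = gaussian_nll_unit_variance(target, prediction)
sse = sum_squared_error(target, prediction)

# This constant depends only on the dimensionality, not on the prediction.
constant = 0.5 * len(target) * math.log(2.0 * math.pi)
gap = nll - 0.5 * sse  # equals `constant` up to floating-point error
```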
I have considered replacing the log likelihood with a cross-entropy loss, but I think that is not applicable in my case, since my problem is regression and the data can't be normalized into the [0,1] range.
I have also tried an annealed KL term (i.e. weighting the KL term with a constant < 1) while using the log likelihood as the reconstruction loss, but that didn't work either.
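As a rough illustration of what weighting the KL term means, here is a plain-Python sketch using the closed-form KL divergence between two diagonal Gaussians given by mean and logvar. The names `kl_diag_gaussians`, `annealed_loss` and `kl_weight`, as well as the numbers, are hypothetical and only stand in for the corresponding quantities in my model.

```python
import math

def kl_diag_gaussians(q_mean, q_logvar, p_mean, p_logvar):
    """KL(q || p) for two diagonal Gaussians parameterised by mean and log-variance."""
    kl = 0.0
    for qm, qlv, pm, plv in zip(q_mean, q_logvar, p_mean, p_logvar):
        # 0.5 * (log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1)
        kl += 0.5 * (plv - qlv
                     + (math.exp(qlv) + (qm - pm) ** 2) / math.exp(plv)
                     - 1.0)
    return kl

def annealed_loss(neg_log_likelihood, kl, kl_weight):
    """Reconstruction term plus the KL term scaled by a constant weight."""
    return neg_log_likelihood + kl_weight * kl

# Hypothetical inference (q) and transition (p) parameters for one time step.
kl = kl_diag_gaussians([1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0])

# kl_weight < 1 reduces the influence of the KL term on the total loss.
step_loss = annealed_loss(neg_log_likelihood=2.0, kl=kl, kl_weight=0.1)
```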
Here is the code for my original loss function (using the log likelihood as the reconstruction error):
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
import tensorflow_probability as tfp
tfd = tfp.distributions

def loss(model, inputs):
    outputs, _ = SSM_model(model, inputs)

    # allocate the corresponding output component
    infer_mean = outputs[:, :, :latent_dim]  # mean of latent variable from inference net
    infer_logvar = outputs[:, :, latent_dim:(2 * latent_dim)]
    trans_mean = outputs[:, :, (2 * latent_dim):(3 * latent_dim)]  # mean of latent variable from transition net
    trans_logvar = outputs[:, :, (3 * latent_dim):(4 * latent_dim)]
    obs_mean = outputs[:, :, (4 * latent_dim):((4 * latent_dim) + output_obs_dim)]  # mean of observation from generative net
    obs_logvar = outputs[:, :, ((4 * latent_dim) + output_obs_dim):]
    target = inputs[:, :, 2:4]

    # transform logvar to std
    infer_std = tf.sqrt(tf.exp(infer_logvar))
    trans_std = tf.sqrt(tf.exp(trans_logvar))
    obs_std = tf.sqrt(tf.exp(obs_logvar))

    # computing loss at each time step
    time_step_loss = []
    for i in range(tf.shape(outputs)[0].numpy()):
        # distribution of each module
        infer_dist = tfd.MultivariateNormalDiag(infer_mean[i], infer_std[i])
        trans_dist = tfd.MultivariateNormalDiag(trans_mean[i], trans_std[i])
        obs_dist = tfd.MultivariateNormalDiag(obs_mean[i], obs_std[i])

        # log likelihood of observation
        likelihood = obs_dist.prob(target[i])  # shape = 1D = batch_size
        likelihood = tf.clip_by_value(likelihood, 1e-37, 1)
        log_likelihood = tf.log(likelihood)

        # KL of (q|p)
        kl = tfd.kl_divergence(infer_dist, trans_dist)  # shape = batch_size

        # the loss
        loss = - log_likelihood + kl
        time_step_loss.append(loss)

    time_step_loss = tf.convert_to_tensor(time_step_loss)
    overall_loss = tf.reduce_sum(time_step_loss)
    overall_loss = tf.cast(overall_loss, dtype='float32')
    return overall_loss
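To look at the log likelihood lines in isolation, the scalar sketch below mirrors the prob, clip, log sequence from the loss above in plain Python and compares it with the log density evaluated directly. The helper names and test values are hypothetical, and the sketch only illustrates the arithmetic of those three lines; it is not a claim about what my full model does.

```python
import math

def normal_pdf(x, mean, std):
    """Density of a scalar normal distribution."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def normal_log_pdf(x, mean, std):
    """Log density of a scalar normal distribution, computed in log-space."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)

def clipped_log_likelihood(x, mean, std, low=1e-37, high=1.0):
    """Scalar mirror of prob -> clip_by_value(1e-37, 1) -> log."""
    density = normal_pdf(x, mean, std)
    clipped = min(max(density, low), high)
    return math.log(clipped)

# Two hypothetical predictions that are both far from the target, one much farther.
far = clipped_log_likelihood(x=30.0, mean=0.0, std=1.0)
farther = clipped_log_likelihood(x=60.0, mean=0.0, std=1.0)

# The same two cases with the log density evaluated directly.
far_direct = normal_log_pdf(30.0, 0.0, 1.0)
farther_direct = normal_log_pdf(60.0, 0.0, 1.0)

# A confident, exact prediction: the density exceeds 1 and hits the upper clip.
tight = clipped_log_likelihood(x=0.0, mean=0.0, std=0.1)
tight_direct = normal_log_pdf(0.0, 0.0, 0.1)
```

In this sketch `far` and `farther` come out identical because both densities fall below the lower clip, while `far_direct` and `farther_direct` still differ. Likewise `tight` is capped at 0 while `tight_direct` is positive.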