
I'm currently working on a variation of a Variational Autoencoder in a sequential setting, where the task is to fit/recover a sequence of real-valued observations (hence it is a regression problem).

I have built my model using tf.keras with eager execution enabled, together with tensorflow_probability (tfp). Following the VAE concept, the generative net emits the distribution parameters of the observation data, which I model as a multivariate normal. Therefore the outputs are the mean and log-variance of the predicted distribution.

Regarding the training process, the first component of the loss is the reconstruction error, i.e. the log-likelihood of the true observation under the distribution whose parameters are predicted by the generative net. Here I use tfp.distributions, since it is fast and handy.
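For reference, the quantity involved is the diagonal-Gaussian log-density; `tfd.MultivariateNormalDiag(...).log_prob` computes the same thing in closed form. A minimal NumPy sketch of that density, with the mean/log-variance parameterization used above (function name and shapes are just for illustration):

```python
import numpy as np

def diag_gaussian_log_prob(x, mean, logvar):
    # log N(x; mean, diag(exp(logvar))), summed over the last (event) axis:
    # -0.5 * sum(log(2*pi) + logvar + (x - mean)^2 / exp(logvar))
    return -0.5 * np.sum(
        np.log(2.0 * np.pi) + logvar + (x - mean) ** 2 / np.exp(logvar),
        axis=-1,
    )

# sanity check: 1-D standard normal evaluated at its mean
print(diag_gaussian_log_prob(np.zeros(1), np.zeros(1), np.zeros(1)))  # ≈ -0.9189
```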

However, after training finishes with a considerably low loss value, it turns out that my model does not seem to learn anything: the model's predictions are nearly flat across the time dimension (recall that the problem is sequential).

Nevertheless, as a sanity check, when I replace the log-likelihood with an MSE loss (which is not justifiable when working with a VAE), it yields a very good fit. So I conclude that something must be wrong with the log-likelihood term. Does anyone have a clue and/or a solution for this?
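One way to see why MSE "works" here: with the variance held fixed (logvar = 0), the negative Gaussian log-likelihood is just half the squared error plus a constant, so its gradients with respect to the mean are proportional to MSE gradients. A small NumPy check of that identity (variable names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))   # targets
mu = rng.normal(size=(4, 3))  # predicted means

# negative log-likelihood of a unit-variance diagonal Gaussian,
# summed over the event axis
nll = 0.5 * np.sum(np.log(2.0 * np.pi) + (x - mu) ** 2, axis=-1)

# half the summed squared error, plus the normalizing constant
half_sse = 0.5 * np.sum((x - mu) ** 2, axis=-1)
const = 0.5 * x.shape[-1] * np.log(2.0 * np.pi)

print(np.allclose(nll, half_sse + const))  # True
```

If this fixed-variance special case fits well while the full log-likelihood fails, the learned variance (or how the likelihood is computed numerically) is the natural suspect.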

I have considered replacing the log-likelihood with a cross-entropy loss, but I don't think that is applicable in my case, since my problem is regression and the data cannot be normalized into the [0, 1] range.

I have also tried annealing the KL term (i.e. weighting the KL term with a constant < 1) while using the log-likelihood as the reconstruction loss, but that didn't work either.
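For completeness, the annealing I tried uses a fixed constant < 1; a common variant instead ramps the weight linearly from 0 to 1 over a warm-up period. A sketch of such a schedule (the warm-up length is an arbitrary choice):

```python
def kl_weight(step, warmup_steps=1000):
    # linear KL annealing: 0 at step 0, growing to 1 at warmup_steps,
    # then held at 1 for the rest of training
    return min(1.0, step / float(warmup_steps))

# the per-step loss would then be: -log_likelihood + kl_weight(step) * kl
```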

Here is a code snippet of my original loss function (using the log-likelihood as the reconstruction error):

    import tensorflow as tf
    tfe = tf.contrib.eager
    tf.enable_eager_execution()

    import tensorflow_probability as tfp
    tfd = tfp.distributions

    def loss(model, inputs):
        outputs, _ = SSM_model(model, inputs)

        #allocate the corresponding output component
        infer_mean = outputs[:,:,:latent_dim]  #mean of latent variable from inference net
        infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
        trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)] #mean of latent variable from transition net
        trans_logvar = outputs[:,:, (3 * latent_dim):(4 * latent_dim)]
        obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + output_obs_dim)] #mean of observation from generative net
        obs_logvar = outputs[:,:,((4 * latent_dim) + output_obs_dim):]
        target = inputs[:,:,2:4]

        #transform logvar to std
        infer_std = tf.sqrt(tf.exp(infer_logvar))
        trans_std = tf.sqrt(tf.exp(trans_logvar))
        obs_std = tf.sqrt(tf.exp(obs_logvar))

        #computing loss at each time step
        time_step_loss = []
        for i in range(tf.shape(outputs)[0].numpy()):
            #distribution of each module
            infer_dist = tfd.MultivariateNormalDiag(infer_mean[i],infer_std[i])
            trans_dist = tfd.MultivariateNormalDiag(trans_mean[i],trans_std[i])
            obs_dist = tfd.MultivariateNormalDiag(obs_mean[i],obs_std[i])

            #log likelihood of observation
            likelihood = obs_dist.prob(target[i]) #shape = 1D = batch_size
            likelihood = tf.clip_by_value(likelihood, 1e-37, 1)
            log_likelihood = tf.log(likelihood)

            #KL of (q|p)
            kl = tfd.kl_divergence(infer_dist, trans_dist) #shape = batch_size

            #the loss
            loss = - log_likelihood + kl
            time_step_loss.append(loss)

        time_step_loss = tf.convert_to_tensor(time_step_loss)        
        overall_loss = tf.reduce_sum(time_step_loss)
        overall_loss = tf.cast(overall_loss, dtype='float32')

        return overall_loss
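One thing worth double-checking in the snippet above is the prob → clip → log sequence. For a target far from the predicted mean, the density underflows below the clip floor of 1e-37, so the loss saturates at log(1e-37) ≈ -85.2 and its gradient with respect to the predicted mean becomes exactly zero, which would leave the predictions flat; computing `log_prob` directly (which tfd distributions provide) avoids this. A NumPy illustration of the saturation, using a 1-D standard normal for simplicity:

```python
import numpy as np

def log_density(x):
    # exact log-density of a 1-D standard normal
    return -0.5 * (np.log(2.0 * np.pi) + x ** 2)

def clipped_log_density(x):
    # prob -> clip -> log, mimicking the loss function above
    p = np.exp(log_density(x))
    return np.log(np.clip(p, 1e-37, 1.0))

for x in (1.0, 15.0, 30.0):
    # for x = 15 and x = 30 the clipped version returns the same
    # saturated value, while the exact log-density keeps decreasing
    print(x, clipped_log_density(x), log_density(x))
```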
