1

I'm currently trying to implement a version of a variational autoencoder (VAE) in a sequential setting. I am working in TensorFlow with eager execution enabled.

The problem setting: I have two sequences of variables, actions (2-D) and observations (2-D), and the actions are assumed to affect the observations. The goal is to model (recover) the observation sequence conditioned on the actions. To do so, we introduce hidden latent variables, just as in a VAE, but in a sequential fashion.

Below is the generative flow of the problem. o, a, and s are observation, action, and latent variable, respectively.

The generative flow of the modeling problem

Just like in a VAE, neural nets are used to parameterize the distributions of the involved variables. Here I assume that every variable follows a multivariate normal distribution with a diagonal covariance matrix (i.e. no cross-covariance terms).

There are three neural nets involved: the inference net, the transition net, and the generative net. Each of them outputs the mean and log-variance vectors of the corresponding variable. A picture of all three nets together is given here: [figure: relation of the three nets]

Legend for the picture: A: mean of the latent variable from the inference net; B: log-variance of the latent variable from the inference net; C: mean of the latent variable from the transition net; D: log-variance of the latent variable from the transition net; E: mean of the predicted observation; F: log-variance of the predicted observation.

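The loss we want to minimize is the negative ELBO. The ELBO is the log-likelihood of the true observation under the Gaussian given by E and F, minus the Kullback–Leibler divergence from the latent Gaussian given by A–B to the latent Gaussian given by C–D:
$$\mathrm{ELBO} = \log \mathcal{N}\big(o;\, E,\, \mathrm{diag}(e^{F})\big) - \mathrm{KL}\Big(\mathcal{N}\big(A,\, \mathrm{diag}(e^{B})\big) \,\Big\|\, \mathcal{N}\big(C,\, \mathrm{diag}(e^{D})\big)\Big)$$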

Since the problem requires both a non-standard RNN cell and a non-standard input-output flow, I created my own RNN cell and pass it to the tf.nn.raw_rnn API.

Below is my code implementation:

from __future__ import absolute_import, division, print_function
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()

import os
import numpy as np
import math

from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model

# Training data. Must already hold a time-major float32 tensor at this point.
# Judging from loss() and SuperLoop.loop_fn, columns 0:2 are the action and
# columns 2:4 are the observation.
# NOTE(review): the bare name below raises NameError unless `inputs` is
# loaded earlier; the data-loading code is not shown here.
inputs #shape (time_step, batch_size, input_depth) = (20,1000,4)

# Global configuration variables, read directly by sampling(), SuperLoop and
# init_state below.
max_time = 20     # number of time steps unrolled by raw_rnn
batch_size = 1000
latent_dim = 4

# Initial latent state (all zeros). SuperLoop.loop_fn uses it at time 0, and
# SSM.zero_state returns it.
init_state = tf.zeros([batch_size, latent_dim])

#sampling and reparameterizing function
def sampling(args):
    """Draw a reparameterized sample from N(mean, diag(exp(logvar))).

    Used inside the Keras ``Lambda`` layer of the inference net.

    Args:
        args: a pair ``(mean, logvar)`` of tensors with the same shape.

    Returns:
        ``mean + exp(0.5 * logvar) * epsilon`` with ``epsilon ~ N(0, I)``,
        with the same shape as ``mean``.
    """
    mean, logvar = args
    # Draw the noise with the runtime shape of `mean` rather than the global
    # (batch_size, latent_dim), so any batch size works.
    # random_normal defaults to mean 0 and stddev 1.
    epsilon = tf.random_normal(shape=tf.shape(mean))
    return mean + tf.exp(0.5 * logvar) * epsilon

#class of the model, in fact it is also an RNN cell
class SSM(tf.keras.Model):
    """State-space model used as a custom RNN cell for ``tf.nn.raw_rnn``.

    Holds three sub-networks:

    * ``inference_net``: input width ``latent_dim + observation_dim + 2``.
      Outputs ``[mean, logvar, s]``, where ``s`` is a reparameterized sample
      of the latent variable.
    * ``transition_net``: input width ``latent_dim + 2``. Outputs
      ``[mean, logvar]`` of the latent variable.
    * ``generative_net``: input is a latent sample of width ``latent_dim``.
      Outputs ``[mean, logvar]`` of the observation.

    The ``+ 2`` in the input widths is the (hard-coded) action width.
    """

    def __init__(self, latent_dim = 4, observation_dim = 2):
        super(SSM, self).__init__()
        self.latent_dim = latent_dim
        self.observation_dim = observation_dim
        # Total cell-input width: inference part + transition part (14 by default).
        self.input_dim = (self.latent_dim + self.observation_dim + 2) + (self.latent_dim + 2)

        # Inference net
        inference_input = Input(shape=(self.latent_dim + self.observation_dim + 2,), name='inference_input')
        layer_1 = Dense(30, activation='tanh')(inference_input)
        layer_2 = Dense(30, activation='tanh')(layer_1)
        inference_mean = Dense(latent_dim, name='inference_mean')(layer_2)
        inference_logvar = Dense(latent_dim, name='inference_logvar')(layer_2)
        s = Lambda(sampling, output_shape=(latent_dim,), name='s')([inference_mean, inference_logvar])
        self.inference_net = Model(inference_input, [inference_mean, inference_logvar, s], name='inference_net')

        # Transition net. The input layer gets its own name rather than
        # reusing the model name 'transition_net'.
        trans_input = Input(shape=(self.latent_dim + 2,), name='transition_input')
        layer_1a = Dense(20, activation='tanh')(trans_input)
        layer_2a = Dense(20, activation='tanh')(layer_1a)
        trans_mean = Dense(latent_dim, name='trans_mean')(layer_2a)
        trans_logvar = Dense(latent_dim, name='trans_logvar')(layer_2a)
        self.transition_net = Model(trans_input, [trans_mean, trans_logvar], name='transition_net')

        # Generative net
        latent_inputs = Input(shape=(self.latent_dim,), name='s_sampling')
        layer_3 = Dense(10, activation='tanh')(latent_inputs)
        layer_4 = Dense(10, activation='tanh')(layer_3)
        obs_mean = Dense(observation_dim, name='observation_mean')(layer_4)
        obs_logvar = Dense(observation_dim, name='observation_logvar')(layer_4)
        self.generative_net = Model(latent_inputs, [obs_mean, obs_logvar], name='generative_net')

    @property
    def state_size(self):
        """Width of the cell state (the latent sample)."""
        return self.latent_dim

    @property
    def output_size(self):
        """Width of one cell output: mean & logvar of the latent variable
        (inference and transition versions) plus mean & logvar of the observation."""
        return (2 * self.latent_dim) + (2 * self.latent_dim) + (2 * self.observation_dim)

    @property
    def zero_state(self):
        """The module-level ``init_state``.

        Exposed as a property, unlike the ``zero_state(batch_size, dtype)``
        method of ``tf.nn.rnn_cell.RNNCell``. raw_rnn here takes its initial
        state from ``SuperLoop.loop_fn`` instead.
        """
        return init_state  # global variable defined at module level

    def __call__(self, inputs, state):
        """Run one time step of the cell.

        Args:
            inputs: 2-D tensor (batch, ``self.input_dim``). The first
                ``latent_dim + observation_dim + 2`` columns feed the inference
                net; the remaining ``latent_dim + 2`` columns feed the
                transition net.
            state: previous cell state. Not read here; the previous latent
                sample reaches the nets through ``inputs``.

        Returns:
            ``(output, next_state)``.
            ``output`` concatenates, along the last axis, the inference
            mean/logvar, the transition mean/logvar and the observation
            mean/logvar.
            ``next_state`` is the latent sample drawn by the inference net.
        """
        split = self.latent_dim + self.observation_dim + 2
        infer_in = inputs[:, :split]
        trans_in = inputs[:, split:]

        # Call each net exactly once. The inference net returns the mean, the
        # logvar and a sample drawn from that same mean/logvar in one pass.
        infer_mean, infer_logvar, next_state = self.inference_net(infer_in)
        trans_mean, trans_logvar = self.transition_net(trans_in)
        obs_mean, obs_logvar = self.generative_net(next_state)

        output = tf.concat([infer_mean, infer_logvar, trans_mean, trans_logvar, obs_mean, obs_logvar], -1)
        return output, next_state

#define a class with instant having method called loop_fn (needed for tf.nn.raw_rnn)
class SuperLoop:
    """Holds the per-step inputs and provides ``loop_fn`` for ``tf.nn.raw_rnn``.

    At each step the cell input is ``[x_t, s_prev, x_t[:, :2], s_prev]``.
    The first half (full input row + previous latent) feeds the inference
    net. The second half (first two columns + previous latent) feeds the
    transition net.
    Every latent state produced by the cell is written to a TensorArray that
    is carried as the raw_rnn loop state.
    """

    def __init__(self, inputs, output_dim = 20): # 20 = 4*latent_dim + 2*observation_dim
        """
        Args:
            inputs: time-major sequence with ``max_time`` steps; presumably
                shaped (max_time, batch_size, input_depth) as noted for the
                global ``inputs``.
            output_dim: width of one cell output, used for the time-0 emit
                structure.
        """
        # clear_after_read=False because every step is read twice in loop_fn.
        inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
        self.inputs_ta = inputs_ta.unstack(inputs)
        self.output_dim = output_dim
        self.output_ta = tf.TensorArray(dtype=tf.float32, size=max_time)  # for saving the states

    @staticmethod
    def _build_input(step_input, state):
        """Concatenate the inference part and the transition part of one cell input."""
        return tf.concat([step_input, state, step_input[:, :2], state], -1)

    def loop_fn(self, time, cell_output, cell_state, loop_state):
        """raw_rnn callback.

        Returns:
            ``(elements_finished, next_input, next_cell_state, emit_output,
            next_loop_state)`` as required by ``tf.nn.raw_rnn``.
        """
        if cell_output is None:
            # time == 0: provide the initial state, the emit structure (one
            # output row without the batch dimension) and the initial loop state.
            next_cell_state = init_state
            emit_output = tf.zeros([self.output_dim])
            next_loop_state = self.output_ta
        else:
            emit_output = cell_output
            next_cell_state = cell_state
            # Save the sampled latent variables produced at step time-1.
            next_loop_state = loop_state.write(time - 1, next_cell_state)

        # A single scalar flag: all sequences share the same length max_time.
        elements_finished = (time >= max_time)
        finished = tf.reduce_all(elements_finished)

        # Python `if` on a tensor: relies on eager execution.
        if finished:
            # The cell is not run again. The dummy input still gets the shape
            # of a real input (batch, input width) rather than (output_dim,).
            last_input = self.inputs_ta.read(max_time - 1)
            next_input = tf.zeros_like(self._build_input(last_input, next_cell_state))
        else:
            next_input = self._build_input(self.inputs_ta.read(time), next_cell_state)

        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)


def SSM_model(inputs, RNN_cell = None, output_dim = 20):
    """Unroll an SSM cell over ``inputs`` with ``tf.nn.raw_rnn``.

    Args:
        inputs: time-major input sequence, passed to ``SuperLoop``.
        RNN_cell: the cell to unroll. If None, a fresh, untrained ``SSM()``
            is built for this call; pass your model explicitly when training.
        output_dim: width of one cell output (default 20).

    Returns:
        ``[obs, latent]``.
        ``obs`` is the stacked per-step cell output.
        ``latent`` holds the stacked latent samples saved by ``SuperLoop``.
        Both have NaNs replaced by zeros.
    """
    # Build the default lazily. This creates no model at import time and no
    # hidden model shared between calls.
    if RNN_cell is None:
        RNN_cell = SSM()
    superloop = SuperLoop(inputs, output_dim)
    outputs_ta, final_state, final_loop_state = tf.nn.raw_rnn(RNN_cell, superloop.loop_fn)
    # outputs_ta is still a TensorArray, hence it needs to be stacked.
    obs = outputs_ta.stack()
    # NOTE: zeroing NaNs hides numerical problems, and tf.where passes no
    # gradient to the masked positions.
    obs = tf.where(tf.is_nan(obs), tf.zeros_like(obs), obs)
    # The final loop state contains the sampled latent variables.
    latent = final_loop_state.stack()
    latent = tf.where(tf.is_nan(latent), tf.zeros_like(latent), latent)
    return [obs, latent]

# The model instance. It is also the RNN cell that SSM_model unrolls with raw_rnn.
model = SSM()

#Define the loss: negative of ELBO, ElBO = log p(o|s) - KL(infer|trans)
def KL(infer_mean, infer_logvar, trans_mean, trans_logvar, latent_dim = 4):
    """KL( N(infer_mean, diag(exp(infer_logvar))) || N(trans_mean, diag(exp(trans_logvar))) ).

    Closed form for two Gaussians with diagonal covariance, summed over the
    last (latent) axis:

        KL = 0.5 * ( sum(var_phi / var_gamma)                         # trace term
                     + sum((trans_mean - infer_mean)**2 / var_gamma)  # Mahalanobis term
                     - latent_dim
                     + sum(log var_gamma - log var_phi) )             # log-det ratio

    Args:
        infer_mean, infer_logvar: mean / log-variance from the inference net.
        trans_mean, trans_logvar: mean / log-variance from the transition net.
        latent_dim: dimensionality k of the latent variable.

    Returns:
        Tensor with the input shape minus the last axis.
    """
    eps = 10e-5  # == 1e-4; keeps the variances away from zero
    var_gamma = tf.exp(trans_logvar) + eps
    var_phi = tf.exp(infer_logvar) + eps

    # With diagonal covariances the matrix trace, quadratic form and log-det
    # reduce to element-wise operations, so no (k, k) matrices are needed.
    term_1 = tf.reduce_sum(var_phi / var_gamma, axis=-1)
    term_2 = tf.reduce_sum(tf.square(trans_mean - infer_mean) / var_gamma, axis=-1)
    term_3 = float(latent_dim)
    term_4 = tf.reduce_sum(tf.log(var_gamma) - tf.log(var_phi), axis=-1)

    KL_term = 0.5 * (term_1 + term_2 - term_3 + term_4)
    return KL_term

#log of probability p(o|s)
def log_prob(value, obs_mean, obs_logvar, observation_dim = 2):
    """Log density of ``value`` under N(obs_mean, diag(exp(obs_logvar))).

        log p = -0.5 * ( d * log(2*pi) + sum(log var) + sum((value - mean)**2 / var) )

    Args:
        value: observed values, same shape as ``obs_mean``.
        obs_mean, obs_logvar: mean / log-variance from the generative net.
        observation_dim: dimensionality d of the observation.

    Returns:
        Tensor with the input shape minus the last axis.
    """
    eps = 10e-5  # == 1e-4; keeps the variance away from zero
    var_theta = tf.exp(obs_logvar) + eps

    # Diagonal covariance: the density factorizes over the last axis, so the
    # log-det and quadratic form are plain sums.
    logprob = -0.5 * (observation_dim * math.log(2.0 * math.pi)
                      + tf.reduce_sum(tf.log(var_theta), axis=-1)
                      + tf.reduce_sum(tf.square(value - obs_mean) / var_theta, axis=-1))
    return logprob

def loss(model, inputs, latent_dim = 4, observation_dim = 2):
    """Negative ELBO: mean KL(infer || trans) minus mean log p(o | s).

    Args:
        model: the SSM cell to unroll.
        inputs: time-major inputs. Columns ``2 : 2 + observation_dim`` hold
            the observation.
        latent_dim, observation_dim: sizes used to split the cell output.

    Returns:
        Scalar loss tensor.
    """
    outputs = SSM_model(inputs, model)[0]  # only the cell outputs are needed

    # Split the concatenated cell output back into its six parts.
    infer_mean = outputs[:, :, :latent_dim]
    infer_logvar = outputs[:, :, latent_dim:(2 * latent_dim)]
    trans_mean = outputs[:, :, (2 * latent_dim):(3 * latent_dim)]
    trans_logvar = outputs[:, :, (3 * latent_dim):(4 * latent_dim)]
    obs_mean = outputs[:, :, (4 * latent_dim):((4 * latent_dim) + observation_dim)]
    obs_logvar = outputs[:, :, ((4 * latent_dim) + observation_dim):]

    # Log-likelihood term; the observation follows the two action columns.
    value = inputs[:, :, 2:2 + observation_dim]
    logprob = log_prob(value, obs_mean, obs_logvar, observation_dim)
    logprob = tf.reduce_mean(logprob)

    # KL term
    KL_term = KL(infer_mean, infer_logvar, trans_mean, trans_logvar, latent_dim)
    KL_term = tf.reduce_mean(KL_term)

    return KL_term - logprob

#computing gradient function
def compute_gradients(model, x):
    """Evaluate the loss under a GradientTape and differentiate it.

    Returns:
        ``(gradients, loss_value)``. The gradients are ordered like
        ``model.trainable_variables``.
    """
    with tf.GradientTape() as tape:
        loss_value = loss(model, x)
    grads = tape.gradient(loss_value, model.trainable_variables)
    return grads, loss_value

# Returns (gradients, loss_value). The gradients are ordered like model.trainable_variables.
compute_gradients(model, inputs)

The last line yields zero gradients for the transition net and the generative net, so I can't proceed any further. Does anyone have a clue why the gradients of the transition and generative nets are zero? I suspect my model-construction code is still wrong, but I have no idea how to fix it.

0 Answers0