I'm currently trying to implement a variational autoencoder in a sequential setting. I'm working in TensorFlow with eager execution enabled.
As the problem setting, I have two sequences of variables: actions (2-D) and observations (2-D). The actions are assumed to affect the observations. The goal is to model (recover) the observation sequence conditioned on the actions. To do so, we 'ask for help' from latent (hidden) variables, just as in a VAE, but in a sequential fashion.
Below is the generative flow of the problem; o, a, and s denote observation, action, and latent variable, respectively.
[figure: the generative flow of the modeling problem]
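In symbols (this is my reading of the figure, so treat the exact conditioning sets as an assumption):

s_t ~ p(s_t | s_{t-1}, a_t)        # transition prior
o_t ~ p(o_t | s_t)                 # emission
s_t ~ q(s_t | s_{t-1}, a_t, o_t)   # approximate posterior used for inference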
Just like in a VAE, neural nets are used to parameterize the distributions of the variables involved. Here I assume that every variable follows a multivariate normal distribution with a diagonal covariance matrix (i.e., no off-diagonal covariance terms).
There are three neural nets involved: an inference net, a transition net, and a generative net. Each of them emits the mean and log-variance vector of the corresponding variable. A picture of all three nets together is given here:
[figure: relation between the three nets]
Legend for the picture:
A: mean of the latent variable from the inference net
B: log variance of the latent variable from the inference net
C: mean of the latent variable from the transition net
D: log variance of the latent variable from the transition net
E: mean of the predicted observation
F: log variance of the predicted observation
The loss we want to minimize is the negative ELBO. The ELBO itself is the log-likelihood of the true observation under the normal distribution given by E and F, minus the Kullback-Leibler divergence between the normal distributions given by A-B and C-D.
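Per time step, and in terms of the legend above, the quantity being computed is (my own shorthand; it matches the loss() code at the bottom):

ELBO_t = log N(o_t | E, diag(exp F)) - KL( N(A, diag(exp B)) || N(C, diag(exp D)) )
loss   = -mean(ELBO_t) = KL_term - logprob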
Since the problem requires both a non-standard RNN cell and a non-standard input-output flow, I created my own RNN cell, which I later pass to the tf.nn.raw_rnn API.
Below is my code implementation:
from __future__ import absolute_import, division, print_function
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
import os
import numpy as np
import math
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model
#training data (loaded elsewhere):
#`inputs` has shape (time_step, batch_size, input_depth) = (20, 1000, 4),
#where the last axis is the 2-D action followed by the 2-D observation
#global configuration variables
max_time = 20
batch_size = 1000
latent_dim = 4
#initial state
init_state = tf.zeros([batch_size, latent_dim])
#sampling and reparameterizing function
def sampling(args):
    mean, logvar = args
    batch = batch_size
    dim = latent_dim
    # by default, tf.random_normal has mean = 0 and stddev = 1.0
    epsilon = tf.random_normal(shape=(batch, dim))
    return mean + tf.exp(0.5 * logvar) * epsilon
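A quick sanity check of the reparameterization (throwaway values, just to illustrate):

m = tf.zeros([batch_size, latent_dim])
lv = tf.zeros([batch_size, latent_dim])
print(sampling([m, lv]).shape)  # (1000, 4), draws from N(0, I)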
#class of the model; in fact it is also an RNN cell
class SSM(tf.keras.Model):
    def __init__(self, latent_dim = 4, observation_dim = 2):
        super(SSM, self).__init__()
        self.latent_dim = latent_dim
        self.observation_dim = observation_dim
        # input of inference net (latent + observation + action) plus transition net (latent + action)
        self.input_dim = (self.latent_dim + self.observation_dim + 2) + (self.latent_dim + 2)
        #inference net
        inference_input = Input(shape=(self.latent_dim + self.observation_dim + 2,), name='inference_input')
        layer_1 = Dense(30, activation='tanh')(inference_input)
        layer_2 = Dense(30, activation='tanh')(layer_1)
        inference_mean = Dense(latent_dim, name='inference_mean')(layer_2)
        inference_logvar = Dense(latent_dim, name='inference_logvar')(layer_2)
        s = Lambda(sampling, output_shape=(latent_dim,), name='s')([inference_mean, inference_logvar])
        self.inference_net = Model(inference_input, [inference_mean, inference_logvar, s], name='inference_net')
        #transition net
        trans_input = Input(shape=(self.latent_dim + 2,), name='transition_input')
        layer_1a = Dense(20, activation='tanh')(trans_input)
        layer_2a = Dense(20, activation='tanh')(layer_1a)
        trans_mean = Dense(latent_dim, name='trans_mean')(layer_2a)
        trans_logvar = Dense(latent_dim, name='trans_logvar')(layer_2a)
        self.transition_net = Model(trans_input, [trans_mean, trans_logvar], name='transition_net')
        #generative net
        latent_inputs = Input(shape=(self.latent_dim,), name='s_sampling')
        layer_3 = Dense(10, activation='tanh')(latent_inputs)
        layer_4 = Dense(10, activation='tanh')(layer_3)
        obs_mean = Dense(observation_dim, name='observation_mean')(layer_4)
        obs_logvar = Dense(observation_dim, name='observation_logvar')(layer_4)
        self.generative_net = Model(latent_inputs, [obs_mean, obs_logvar], name='generative_net')
    @property
    def state_size(self):
        return self.latent_dim
    @property
    def output_size(self):
        #mean & logvar of the latent from the inference and transition nets, plus mean & logvar of the observation
        return (2 * self.latent_dim) + (2 * self.latent_dim) + (2 * self.observation_dim)
    @property
    def zero_state(self):
        return init_state #global variable defined above
    def __call__(self, inputs, state):
        #the input of the RNN cell is 14-dimensional: the first 8 (= latent_dim + observation_dim + 2)
        #feed the inference net, the remaining 6 (without the observation) feed the transition net
        infer_in = inputs[:, :(self.latent_dim + self.observation_dim + 2)]
        trans_in = inputs[:, (self.latent_dim + self.observation_dim + 2):]
        #mean and logvar of the latent variable from the inference net;
        #the next state is the latent variable sampled from the inference net
        infer_mean, infer_logvar, next_state = self.inference_net(infer_in)
        #mean and logvar of the latent variable from the transition net
        trans_mean, trans_logvar = self.transition_net(trans_in)
        #mean and logvar of the observation from the generative net
        obs_mean, obs_logvar = self.generative_net(next_state)
        #the output of the RNN cell is the concatenation of all of them
        output = tf.concat([infer_mean, infer_logvar, trans_mean, trans_logvar, obs_mean, obs_logvar], -1)
        return output, next_state
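A quick shape check of the pieces (a throwaway sketch with dummy tensors; the 14-D step-input layout is the one described in the comments above):

cell = SSM()
#sub-nets individually
mean_q, logvar_q, s_sample = cell.inference_net(tf.zeros([batch_size, 4 + 2 + 2]))  # each (1000, 4)
mean_p, logvar_p = cell.transition_net(tf.zeros([batch_size, 4 + 2]))               # each (1000, 4)
mean_o, logvar_o = cell.generative_net(s_sample)                                    # each (1000, 2)
#one full cell step
step_output, new_state = cell(tf.zeros([batch_size, 8 + 6]), init_state)
print(step_output.shape)  # (1000, 20) = 4*latent_dim + 2*observation_dim
print(new_state.shape)    # (1000, 4)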
#define a class whose instances provide the loop_fn method needed by tf.nn.raw_rnn
class SuperLoop:
    def __init__(self, inputs, output_dim = 20): # 20 = 4*latent_dim + 2*observation_dim
        inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time, clear_after_read=False)
        inputs_ta = inputs_ta.unstack(inputs) #the training data, one tensor per time step
        self.inputs_ta = inputs_ta
        self.output_dim = output_dim
        self.output_ta = tf.TensorArray(dtype=tf.float32, size=max_time) #for saving the sampled states
    def loop_fn(self, time, cell_output, cell_state, loop_state):
        if cell_output is None: # when time == 0
            next_cell_state = init_state
            emit_output = tf.zeros([self.output_dim])
            next_loop_state = self.output_ta
        else:
            emit_output = cell_output
            next_cell_state = cell_state
            #save the sampled latent variables
            next_loop_state = loop_state.write(time - 1, next_cell_state)
        elements_finished = (time >= max_time)
        finished = tf.reduce_all(elements_finished)
        if finished:
            #dummy input once the sequence has ended; 14 = cell input depth (8 inference + 6 transition)
            next_input = tf.zeros(shape=(batch_size, 14), dtype=tf.float32)
        else:
            #cell's next input: [a_t, o_t, s_{t-1}] for the inference net,
            #followed by [a_t, s_{t-1}] for the transition net
            next_input = tf.concat([self.inputs_ta.read(time), next_cell_state,
                                    self.inputs_ta.read(time)[:, :2], next_cell_state], -1)
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)
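To make the loop_fn contract explicit: tf.nn.raw_rnn calls loop_fn once with cell_output == None before the first step, and then once after every cell invocation. Roughly, it behaves like this (simplified pseudocode of raw_rnn, not part of my model):

finished, next_input, state, emit, loop_state = loop_fn(0, None, None, None)
time = 0
while not finished:
    output, state = cell(next_input, state)
    finished, next_input, state, emit, loop_state = loop_fn(time + 1, output, state, loop_state)
    time += 1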
def SSM_model(inputs, RNN_cell = SSM(), output_dim = 20):
    superloop = SuperLoop(inputs, output_dim)
    outputs_ta, final_state, final_loop_state = tf.nn.raw_rnn(RNN_cell, superloop.loop_fn)
    #outputs_ta is still a TensorArray, hence it needs to be stacked
    obs = outputs_ta.stack()
    obs = tf.where(tf.is_nan(obs), tf.zeros_like(obs), obs)
    #the final loop state contains the sampled latent variables
    latent = final_loop_state.stack()
    latent = tf.where(tf.is_nan(latent), tf.zeros_like(latent), latent)
    observation_latent = [obs, latent]
    return observation_latent
#cell == model instance
model = SSM()
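Running the forward pass on the training data should then give (a sketch; `inputs` as described at the top of the post):

outputs, latents = SSM_model(inputs, model)
print(outputs.shape)  # (20, 1000, 20): per-step means/logvars from all three nets
print(latents.shape)  # (20, 1000, 4): per-step sampled latent states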
#Define the loss: the negative of the ELBO, where ELBO = log p(o|s) - KL(infer || trans)
def KL(infer_mean, infer_logvar, trans_mean, trans_logvar, latent_dim = 4):
    eps = 10e-5 #to ensure nonsingularity
    var_gamma = tf.exp(trans_logvar) + eps #diagonal of the transition covariance, shape (20,1000,4)
    var_phi = tf.exp(infer_logvar) + eps #diagonal of the inference covariance
    #analytic KL divergence between two multivariate normals; since both covariances are
    #diagonal, every term can be computed elementwise:
    #KL(N_phi || N_gamma) = 0.5 * (log|S_g|/|S_p| + tr(S_g^-1 S_p) - k + (m_g - m_p)^T S_g^-1 (m_g - m_p))
    term_1 = tf.reduce_sum(tf.log(var_gamma) - tf.log(var_phi), axis=-1) #log-determinant ratio
    term_2 = tf.reduce_sum(var_phi / var_gamma, axis=-1) #trace term
    term_3 = tf.cast(latent_dim, tf.float32) #dimensionality k
    term_4 = tf.reduce_sum(tf.square(trans_mean - infer_mean) / var_gamma, axis=-1) #Mahalanobis term
    KL_term = 0.5 * (term_1 + term_2 - term_3 + term_4)
    return KL_term
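As an optional cross-check of this closed form (assuming tf.contrib.distributions is available in this TF 1.x setup; the tensors here are the same arguments KL receives):

tfd = tf.contrib.distributions
q = tfd.MultivariateNormalDiag(loc=infer_mean, scale_diag=tf.exp(0.5 * infer_logvar))
p = tfd.MultivariateNormalDiag(loc=trans_mean, scale_diag=tf.exp(0.5 * trans_logvar))
kl_builtin = tfd.kl_divergence(q, p)  # should match KL(...) up to the eps jitter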
#log of probability p(o|s)
def log_prob(value, obs_mean, obs_logvar, observation_dim = 2):
    eps = 10e-5
    var_theta = tf.exp(obs_logvar) + eps #diagonal of the observation covariance
    #log density of `value` under Multivariate Normal(obs_mean, diag(var_theta)),
    #summed over the observation dimensions
    logprob = -0.5 * tf.reduce_sum(
        tf.log(2.0 * math.pi) + tf.log(var_theta) + tf.square(value - obs_mean) / var_theta,
        axis=-1)
    return logprob
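And the corresponding cross-check for the likelihood term (same tf.contrib.distributions assumption as above):

logprob_builtin = tfd.MultivariateNormalDiag(
    loc=obs_mean, scale_diag=tf.exp(0.5 * obs_logvar)).log_prob(value)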
def loss(model, inputs, latent_dim = 4, observation_dim = 2):
    outputs = SSM_model(inputs, model)[0] #only the net outputs are needed to compute the loss
    #slice the concatenated cell output back into its six components
    infer_mean = outputs[:,:,:latent_dim]
    infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
    trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)]
    trans_logvar = outputs[:,:,(3 * latent_dim):(4 * latent_dim)]
    obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + observation_dim)]
    obs_logvar = outputs[:,:,((4 * latent_dim) + observation_dim):]
    #log-likelihood term
    value = inputs[:,:,2:4] #the observation occupies dims 2:4 of the inputs
    logprob = log_prob(value, obs_mean, obs_logvar, observation_dim)
    logprob = tf.reduce_mean(logprob)
    #KL term
    KL_term = KL(infer_mean, infer_logvar, trans_mean, trans_logvar, latent_dim)
    KL_term = tf.reduce_mean(KL_term)
    #negative ELBO
    return KL_term - logprob
#gradient-computing function
def compute_gradients(model, x):
    with tf.GradientTape() as tape:
        loss_value = loss(model, x)
    return tape.gradient(loss_value, model.trainable_variables), loss_value
compute_gradients(model, inputs)
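For completeness, this is how I intend to apply the gradients once they are correct (a standard TF 1.x eager training step; the learning rate is an arbitrary placeholder):

optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
gradients, loss_value = compute_gradients(model, inputs)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))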
The call to compute_gradients(model, inputs) returns zero gradients for the transition net and the generative net, so I can't proceed any further. Does anyone have a clue why the gradients of the transition and generative nets are zero? I guess my code building the model is still wrong, but I have no idea how to fix it.