I stumbled across a strange phenomenon while playing around with variational autoencoders. The problem is quite simple to describe:

When defining the loss function for the VAE, you have to use some kind of reconstruction error. I decided to use my own implementation of cross-entropy, as I wasn't able to get reasonable results with any of the functions provided by TensorFlow. It looks like this:

x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=tf.sigmoid)

## Define the loss

reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + x_hat) + 
    (1 - x) * tf.log(epsilon + 1 - x_hat),
    axis=1) 

It uses the output of the reconstruction layer, which applies the sigmoid function to squash the values into the [0, 1] range. Now I wanted to apply the sigmoid within the loss function instead, and changed it to

x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=None)

## Define the loss

reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + tf.sigmoid(x_hat)) + 
    (1 - x) * tf.log(epsilon + 1 - tf.sigmoid(x_hat)),
    axis=1) 

I'm convinced that this should provide nearly identical results. In practice, though, this second attempt results in weird grey pictures. The originals seem blurry and much brighter, too. First the okay version, then the alternative "wrong" version.

(Images omitted: reconstructions for the original code, then for the 2nd attempt.)

Can someone explain to me what causes this weird behavior?

If you want to test it yourself, my source code is below. You have to comment the respective blocks in or out to reproduce the two results. Thanks!

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np

mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
n_samples = mnist.train.num_examples
input_dim = mnist.train.images[0].shape[0]
inter_dim = 256
encoding_dim = 5
epsilon = 1e-10
learning_rate = 1e-4
n_epochs = 20
batch_size = 100
width = 28

## Define the variational autoencoder model 

x = tf.placeholder(dtype=tf.float32,
                   shape=[None, input_dim],
                   name='x')

fc1 = tf.contrib.layers.fully_connected(x,
                                        inter_dim,
                                        activation_fn=tf.nn.relu)

z_mean = tf.contrib.layers.fully_connected(fc1,
                                           encoding_dim,
                                           activation_fn=None)
z_log_var = tf.contrib.layers.fully_connected(fc1,
                                              encoding_dim,
                                              activation_fn=None)

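## Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1) and sigma = exp(z_log_var / 2)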
eps = tf.random_normal(shape=tf.shape(z_log_var),
                       mean=0,
                       stddev=1,
                       dtype=tf.float32)
z = z_mean + tf.exp(z_log_var / 2) * eps

fc2 = tf.contrib.layers.fully_connected(z,
                                        inter_dim,
                                        activation_fn=tf.nn.relu)

x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=tf.sigmoid)
                                          #activation_fn=None)
## Define the loss

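## Per-pixel Bernoulli cross-entropy; epsilon avoids taking log(0)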
reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + x_hat) + 
    (1 - x) * tf.log(epsilon + 1 - x_hat),
    axis=1) 

## ALTERNATIVE LOSS: APPLY THE SIGMOID INSIDE THE LOSS FUNCTION
## (requires activation_fn=None on the output layer above)
'''
reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + tf.sigmoid(x_hat)) + 
    (1 - x) * tf.log(epsilon + 1 - tf.sigmoid(x_hat)),
    axis=1)
'''

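## KL divergence between the approximate posterior q(z|x) and the standard normal prior N(0, 1)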
KL_div = -.5 * tf.reduce_sum(
    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
    axis=1)

total_loss = tf.reduce_mean(reconstruction_loss + KL_div)

## Define the training operator

train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(total_loss)

## Run it

with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())

    for epoch in range(n_epochs):
        for _ in range(n_samples // batch_size):
            batch = mnist.train.next_batch(batch_size)

            _, loss, recon_loss, KL_loss = sess.run([train_op,
                                                     total_loss,
                                                     reconstruction_loss,
                                                     KL_div],
                                                    feed_dict={x: batch[0]})
        print('[Epoch {}] loss: {}'.format(epoch, loss))
    print('Training Done')

    ## Reconstruct a few samples to validate the training

    batch = mnist.train.next_batch(100)

    x_reconstructed = sess.run(x_hat, feed_dict={x:batch[0]})

    n = np.sqrt(batch_size).astype(np.int32)
    I_reconstructed = np.empty((width*n, 2*width*n))
    for i in range(n):
        for j in range(n):
            x = np.concatenate(
                (x_reconstructed[i*n+j, :].reshape(width, width),
                 batch[0][i*n+j, :].reshape(width, width)),
                axis=1
            )
            I_reconstructed[i*width:(i+1)*width, j*2*width:(j+1)*2*width] = x

    fig = plt.figure()
    plt.imshow(I_reconstructed, cmap='gray')
    plt.show()

EDIT1: SOLUTION

Thanks to @xdurch0, I was made aware that the reconstructed output is no longer rescaled by the sigmoid function. That means the sigmoid has to be applied to the reconstruction before plotting it. Just modify the output:

x_reconstructed = sess.run(tf.sigmoid(x_hat), feed_dict={x:batch[0]})
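Following the comment discussion below, a numerically more stable alternative would be to keep activation_fn=None on the output layer and let TensorFlow apply the sigmoid inside the loss. A minimal, untested sketch:

## Sketch: x_hat holds logits, so no epsilon hack is needed
reconstruction_loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_hat),
    axis=1)

The sigmoid then only has to be applied explicitly when visualizing the reconstructions, as in the line above.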
  • In the case where you apply the sigmoid within the cost function, shouldn't you also apply the sigmoid to the reconstructions when doing the plots? I don't see you doing that. Without the squashing of the sigmoid the values are much more extreme, which also explains why the plots of the original images (which have less extreme values, constrained between 0 and 1) are messed up. – xdurch0 May 23 '18 at 10:15
  • Thank you, that was indeed the issue. I will edit it into my question. But that makes me wonder: if I decide to use something like `tf.nn.sigmoid_cross_entropy_with_logits` which expects logits as argument, I have to manually apply a sigmoid function whenever I want to use the outputs. This seems clunky to me. – DocDriven May 23 '18 at 11:10
