I stumbled across a strange phenomenon while playing around with variational autoencoders. The problem is quite simple to describe:
When defining the loss function for the VAE, you have to use some kind of reconstruction error. I decided to use my own implementation of cross-entropy, as I wasn't able to get reasonable results with any function provided by TensorFlow. It looks like this:
x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=tf.sigmoid)

## Define the loss
reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + x_hat) +
    (1 - x) * tf.log(epsilon + 1 - x_hat),
    axis=1)
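For context, this is just the summed Bernoulli cross-entropy (the negative log-likelihood of the pixels), with epsilon guarding against log(0):

$$-\sum_j \left[\, x_j \log(\hat{x}_j + \epsilon) + (1 - x_j) \log(1 - \hat{x}_j + \epsilon) \,\right]$$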
It uses the output of the reconstruction layer, which applies the sigmoid function to squash it into the [0, 1] range. Now I wanted to apply the sigmoid within the loss function instead, and changed it to:
x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=None)

## Define the loss
reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + tf.sigmoid(x_hat)) +
    (1 - x) * tf.log(epsilon + 1 - tf.sigmoid(x_hat)),
    axis=1)
I'm convinced that this should give nearly identical results. In practice, though, this second attempt produces weird grey pictures, and the originals look blurry and much brighter, too. First the okay version, then the alternative "wrong" version.
Can someone explain to me what causes this weird behavior?
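For what it's worth, the two loss formulations really do agree numerically. Here is a quick NumPy sanity check with made-up values (the sigmoid helper and the fake x/logits below are stand-ins, not part of the model):

import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

epsilon = 1e-10
np.random.seed(0)
x = (np.random.rand(10) > 0.5).astype(np.float64)  # fake binary pixels
logits = np.random.randn(10)                       # fake pre-activations

## Version 1: sigmoid applied by the output layer, loss sees probabilities
p = sigmoid(logits)
loss1 = -np.sum(x * np.log(epsilon + p) + (1 - x) * np.log(epsilon + 1 - p))

## Version 2: sigmoid applied inside the loss expression itself
loss2 = -np.sum(x * np.log(epsilon + sigmoid(logits)) +
                (1 - x) * np.log(epsilon + 1 - sigmoid(logits)))

print(np.isclose(loss1, loss2))  # True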
If you want to test it yourself, below is my source code. You have to comment the respective blocks in or out to get the results. Thanks!
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import numpy as np
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
n_samples = mnist.train.num_examples
input_dim = mnist.train.images[0].shape[0]
inter_dim = 256
encoding_dim = 5
epsilon = 1e-10
learning_rate = 1e-4
n_epochs = 20
batch_size = 100
width = 28
## Define the variational autoencoder model
x = tf.placeholder(dtype=tf.float32,
                   shape=[None, input_dim],
                   name='x')

fc1 = tf.contrib.layers.fully_connected(x,
                                        inter_dim,
                                        activation_fn=tf.nn.relu)

z_mean = tf.contrib.layers.fully_connected(fc1,
                                           encoding_dim,
                                           activation_fn=None)
z_log_var = tf.contrib.layers.fully_connected(fc1,
                                              encoding_dim,
                                              activation_fn=None)
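## Reparameterization trick: draw eps ~ N(0, I) and build
## z = z_mean + sigma * eps, with sigma = exp(z_log_var / 2)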
eps = tf.random_normal(shape=tf.shape(z_log_var),
                       mean=0,
                       stddev=1,
                       dtype=tf.float32)
z = z_mean + tf.exp(z_log_var / 2) * eps
fc2 = tf.contrib.layers.fully_connected(z,
                                        inter_dim,
                                        activation_fn=tf.nn.relu)

x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=tf.sigmoid)
                                          #activation_fn=None)
## Define the loss
reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + x_hat) +
    (1 - x) * tf.log(epsilon + 1 - x_hat),
    axis=1)
## ALTERNATIVE LOSS W/ SIGMOID APPLIED IN THE LOSS, ACTIVATION REMOVED FROM THE OUTPUT LAYER
'''
reconstruction_loss = -tf.reduce_sum(
    x * tf.log(epsilon + tf.sigmoid(x_hat)) +
    (1 - x) * tf.log(epsilon + 1 - tf.sigmoid(x_hat)),
    axis=1)
'''
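## Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I)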
KL_div = -.5 * tf.reduce_sum(
    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
    axis=1)
total_loss = tf.reduce_mean(reconstruction_loss + KL_div)
## Define the training operator
train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(total_loss)
## Run it
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(n_epochs):
        for _ in range(n_samples // batch_size):
            batch = mnist.train.next_batch(batch_size)
            _, loss, recon_loss, KL_loss = sess.run(
                [train_op, total_loss, reconstruction_loss, KL_div],
                feed_dict={x: batch[0]})

        print('[Epoch {}] loss: {}'.format(epoch, loss))

    print('Training Done')
    ## Reconstruct a few samples to validate the training
    batch = mnist.train.next_batch(100)
    x_reconstructed = sess.run(x_hat, feed_dict={x: batch[0]})

    n = np.sqrt(batch_size).astype(np.int32)
    I_reconstructed = np.empty((width * n, 2 * width * n))
    for i in range(n):
        for j in range(n):
            ## Reconstruction and original side by side; use a new name
            ## here so the placeholder x is not shadowed
            pair = np.concatenate(
                (x_reconstructed[i * n + j, :].reshape(width, width),
                 batch[0][i * n + j, :].reshape(width, width)),
                axis=1)
            I_reconstructed[i * width:(i + 1) * width,
                            j * 2 * width:(j + 1) * 2 * width] = pair

fig = plt.figure()
plt.imshow(I_reconstructed, cmap='gray')
plt.show()
EDIT1: SOLUTION
Thanks to @xdurch0, I was made aware that the reconstructed output is no longer rescaled via the sigmoid function: without the activation on the output layer, x_hat holds raw logits. That means the sigmoid has to be applied to the output before plotting it. Just modify the reconstruction step:
x_reconstructed = sess.run(tf.sigmoid(x_hat), feed_dict={x:batch[0]})
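For reference, if you keep the linear output layer, the numerically stable route (as far as I know) is to let TensorFlow apply the sigmoid inside the loss via tf.nn.sigmoid_cross_entropy_with_logits, and to squash the output only for visualization. A minimal sketch of that variant, with x_hat holding logits:

x_hat = tf.contrib.layers.fully_connected(fc2,
                                          input_dim,
                                          activation_fn=None)

## Stable Bernoulli cross-entropy computed directly on the logits;
## no manual epsilon is needed here
reconstruction_loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_hat),
    axis=1)

## Apply the sigmoid only when you want pixel values in [0, 1]
x_reconstructed = sess.run(tf.sigmoid(x_hat), feed_dict={x: batch[0]})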