
I am having a big problem implementing a Variational Autoencoder: every reconstructed image ends up looking like this: [predicted image]

when the real image is this: [test image]

The training set is CIFAR-10, and the goal is to reconstruct similar images. While the feature maps seem to be predicted correctly, I do not understand why the output looks like this after 50 epochs.

I've tried both fewer and more filters, and am currently using 128. Could this outcome be caused by the network architecture? Or by too few epochs?

The loss function is MSE and the optimizer is RMSprop.

I've also tried implementing this architecture: https://github.com/chaitanya100100/VAE-for-Image-Generation/blob/master/src/cifar10_train.py, with similar results, if not worse.

I am very confused about what the problem might be here. I save the predictions and their real counterparts using matplotlib's pyplot.

Paul

1 Answer


The unhelpful answer is "autoencoders are hard"! Your network has got stuck in a local minimum, predicting the average pixel values (across the entire data set) every time.
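
You can check this directly by comparing the model's predictions against the per-pixel mean image of the training set. A minimal sketch, assuming a Keras model (vae below stands in for your trained model) and CIFAR-10 scaled to [0, 1]:

    import numpy as np
    from tensorflow.keras.datasets import cifar10

    # Load CIFAR-10 and scale pixels to [0, 1], matching a typical VAE setup.
    (x_train, _), (x_test, _) = cifar10.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    # Per-pixel mean over the whole training set.
    mean_image = x_train.mean(axis=0)

    # 'vae' is assumed to be your trained model.
    pred = vae.predict(x_test[:16])

    # If the first number is much smaller than the second, the network has
    # collapsed to predicting the dataset mean.
    print(np.abs(pred - mean_image).mean())   # distance to the mean image
    print(np.abs(pred - x_test[:16]).mean())  # distance to the true images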

I suggest:

  • Vary your learning rate, including dramatically reducing it. You'll probably have to train for longer eventually, but just train for a few epochs and check that it hasn't got stuck predicting the same image each time.
  • Add more filters, as this should make the input-to-output mapping easier to learn, although this somewhat defeats the purpose of an autoencoder, since you are increasing the size of the 'compressed' representation.
  • Try using absolute error for your loss. This helps regress values that are already close to each other (i.e. less than one apart); see the sketch after this list.
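
For the first and third points, a minimal sketch in Keras, assuming the reconstruction loss is passed straight to compile() as in your current setup (vae stands in for your model, and 1e-4 is just a starting point to experiment with):

    from tensorflow.keras.optimizers import RMSprop

    # Dramatically reduced learning rate (the Keras default for RMSprop is 1e-3)
    # and mean absolute error instead of MSE.
    vae.compile(optimizer=RMSprop(learning_rate=1e-4), loss="mae")

    # Train for just a few epochs first and check that the predictions differ
    # from image to image before committing to a long run.
    vae.fit(x_train, x_train, epochs=5, batch_size=128)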

I'm sure others will add suggestions, but I would start with the above.

jmsinusa
  • I would also suggest using Adam. It gives more consistent results when trying out different network architectures. Also, try BCE instead of MSE; MSE can get stuck in a local minimum where the output is the average of all outputs, which is probably what is happening here. – Oringa Nov 28 '18 at 17:28
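
A minimal sketch of the comment's suggestion, assuming pixel values scaled to [0, 1] and a sigmoid activation on the model's output layer (both are required for binary cross-entropy):

    from tensorflow.keras.optimizers import Adam

    # Adam optimizer plus binary cross-entropy, per the comment above.
    # Assumes pixels in [0, 1] and a sigmoid on the final layer.
    vae.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy")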