So this question is about GANs.
I am trying to build a trivial example as my own proof of concept; namely, generating images of handwritten digits (MNIST). While most people approach this with deep convolutional GANs (DCGANs), I am just trying to achieve it with a 1D array (i.e. instead of a 28x28 grid of gray-scale pixel values, a flat 1D array of 28*28 values).
This git repo features a "vanilla" GAN which treats each MNIST image as a 1D array of 784 values. Its output looks pretty acceptable, so I wanted to do something similar.
Import statements
from __future__ import print_function
import matplotlib as mpl
from matplotlib import pyplot as plt
import mxnet as mx
from mxnet import nd, gluon, autograd
from mxnet.gluon import nn, utils
import numpy as np
import os
from math import floor
from random import random
import time
from datetime import datetime
import logging
# Device on which the networks, labels and noise are allocated.
ctx = mx.gpu()
# Seed both generators: the latent noise is drawn with mx.nd.random_normal,
# which uses MXNet's own RNG, so seeding NumPy alone does not make a run
# repeatable.
np.random.seed(3)
mx.random.seed(3)
Hyper parameters
# Training configuration.
batch_size = 100        # samples per optimisation step
epochs = 100            # number of passes over the data iterator

# Adam step sizes, kept separate so the two networks can be tuned on their own.
generator_learning_rate = 1e-3
discriminator_learning_rate = 1e-3

beta1 = 0.5             # handed to both trainers as Adam's "beta1"
latent_z_size = 100     # length of the noise vector fed to the generator
Load data
mnist = mx.test_utils.get_mnist()
# Flatten every image into a 784-vector; -1 lets the row count follow the
# size of the split instead of hard-coding 10000.
# NOTE(review): this trains on the "test_data" split - confirm that is intended.
flattened_training_data = mnist["test_data"].reshape(-1, 28 * 28)
# Rescale the pixels into [-1, 1] so real samples share the range of the
# generator's tanh output.
# Assumes get_mnist() returns pixel values in [0, 1] - TODO confirm.
flattened_training_data = flattened_training_data * 2 - 1
# Batch iterator consumed by the training loop, which calls data_iter.reset()
# and reads batch.data[0]; it is not defined anywhere else in the visible code.
data_iter = mx.io.NDArrayIter(
    data=flattened_training_data,
    batch_size=batch_size,
    shuffle=True,
)
define models
# Generator: latent vector -> 300 hidden units -> 28*28 pixel values.
# The tanh on the last layer bounds every generated pixel to [-1, 1].
G = nn.Sequential()
with G.name_scope():
    G.add(nn.Dense(300, activation="relu"))
    G.add(nn.Dense(28 * 28, activation="tanh"))

# Discriminator: flattened image -> two class scores
# (class 0 = fake, class 1 = real, matching the zeros/ones training labels).
D = nn.Sequential()
with D.name_scope():
    D.add(nn.Dense(128, activation="relu"))
    D.add(nn.Dense(64, activation="relu"))
    D.add(nn.Dense(32, activation="relu"))
    # No activation here: SoftmaxCrossEntropyLoss applies the softmax itself.
    # A tanh in front of it would confine both scores to [-1, 1], capping the
    # score gap at 2 and therefore how confident the discriminator can get.
    D.add(nn.Dense(2))

# Cross-entropy over the two raw class scores produced by D.
loss = gluon.loss.SoftmaxCrossEntropyLoss()
init stuff
# Give both networks the same kind of initializer (Normal(0.02)) on ctx.
for net in (G, D):
    net.initialize(mx.init.Normal(0.02), ctx=ctx)

# One Adam trainer per network, each with its own learning rate.
trainer_G = gluon.Trainer(
    G.collect_params(),
    'adam',
    {"learning_rate": generator_learning_rate, "beta1": beta1},
)
trainer_D = gluon.Trainer(
    D.collect_params(),
    'adam',
    {"learning_rate": discriminator_learning_rate, "beta1": beta1},
)

# Tracks how often the discriminator's predicted class matches the label.
metric = mx.metric.Accuracy()
dynamic plot (for Jupyter notebook)
import matplotlib.pyplot as plt
import time
def dynamic_line_plt(ax, y_data, colors=('r', 'b', 'g'), labels=('Line1', 'Line2', 'Line3')):
    """Draw or refresh one line per series on ``ax`` and redraw its figure.

    On the first call (``ax`` has no lines yet) every series is plotted and a
    legend is added.  On later calls only the data of the existing lines is
    replaced.  Axis limits are refitted on every call and always include 0.

    Args:
        ax: Matplotlib axes to draw on.
        y_data: Sequence of series; each series is a sequence of y values
            plotted against its own indices ``0 .. len(series) - 1``.
        colors: One format string per series; used on the first call only.
        labels: One legend label per series; used on the first call only.
    """
    x_data = [list(range(len(y))) for y in y_data]

    y_min = y_max = 0
    x_min = x_max = 0
    for y in y_data:
        if len(y) == 0:
            # max()/min() raise on an empty series, and it cannot widen the
            # limits anyway.
            continue
        y_max = max(y_max, max(y))
        y_min = min(y_min, min(y))
        x_max = max(x_max, len(y))
    ax.set_ylim(y_min, y_max)
    ax.set_xlim(x_min, x_max)

    if ax.lines:
        # Lines already exist: only swap their data in place.
        for line, xs, ys in zip(ax.lines, x_data, y_data):
            line.set_xdata(xs)
            line.set_ydata(ys)
    else:
        for i, (xs, ys) in enumerate(zip(x_data, y_data)):
            ax.plot(xs, ys, colors[i], label=labels[i])
        ax.legend()

    # Redraw through the figure that owns ``ax`` instead of relying on a
    # module-level ``fig`` being defined by the caller.
    ax.figure.canvas.draw()
train
stamp = datetime.now().strftime('%Y_%m_%d-%H_%M')
logging.basicConfig(level=logging.DEBUG)
# arrays to store data for plotting
loss_D = nd.array([0], ctx=ctx)
loss_G = nd.array([0], ctx=ctx)
acc_d = nd.array([0], ctx=ctx)
labels = ['Discriminator Loss', 'Generator Loss', 'Discriminator Acc.']
%matplotlib notebook
fig, ax = plt.subplots(1, 1)
ax.set_xlabel('Time')
ax.set_ylabel('Loss')
dynamic_line_plt(ax, [loss_D.asnumpy(), loss_G.asnumpy(), acc_d.asnumpy()], labels=labels)
for epoch in range(epochs):
tic = time.time()
data_iter.reset()
for i, batch in enumerate(data_iter):
####################################
# Update Disriminator: maximize log(D(x)) + log(1-D(G(z)))
####################################
# extract batch of real data
data = batch.data[0].as_in_context(ctx)
# add noise
# Produce our noisey input to the generator
latent_z = mx.nd.random_normal(0,1,shape=(batch_size, latent_z_size), ctx=ctx)
# soft and noisy labels
# real_label = mx.nd.ones((batch_size, ), ctx=ctx) * nd.random_uniform(.7, 1.2, shape=(1)).asscalar()
# fake_label = mx.nd.ones((batch_size, ), ctx=ctx) * nd.random_uniform(0, .3, shape=(1)).asscalar()
# real_label = nd.random_uniform(.7, 1.2, shape=(batch_size), ctx=ctx)
# fake_label = nd.random_uniform(0, .3, shape=(batch_size), ctx=ctx)
real_label = mx.nd.ones((batch_size, ), ctx=ctx)
fake_label = mx.nd.zeros((batch_size, ), ctx=ctx)
with autograd.record():
# train with real data
real_output = D(data)
errD_real = loss(real_output, real_label)
# train with fake data
fake = G(latent_z)
fake_output = D(fake.detach())
errD_fake = loss(fake_output, fake_label)
errD = errD_real + errD_fake
errD.backward()
trainer_D.step(batch_size)
metric.update([real_label, ], [real_output,])
metric.update([fake_label, ], [fake_output,])
####################################
# Update Generator: maximize log(D(G(z)))
####################################
with autograd.record():
output = D(fake)
errG = loss(output, real_label)
errG.backward()
trainer_G.step(batch_size)
####
# Plot Loss
####
# append new data to arrays
loss_D = nd.concat(loss_D, nd.mean(errD), dim=0)
loss_G = nd.concat(loss_G, nd.mean(errG), dim=0)
name, acc = metric.get()
acc_d = nd.concat(acc_d, nd.array([acc], ctx=ctx), dim=0)
# plot array
dynamic_line_plt(ax, [loss_D.asnumpy(), loss_G.asnumpy(), acc_d.asnumpy()], labels=labels)
name, acc = metric.get()
metric.reset()
logging.info('Binary training acc at epoch %d: %s=%f' % (epoch, name, acc))
logging.info('time: %f' % (time.time() - tic))
output
# Generate a batch of 100 samples from fresh noise and display the first one
# as a 28x28 gray-scale image.
noise = mx.nd.random_normal(0, 1, shape=(100, latent_z_size), ctx=ctx)
generated = G(noise)
img = generated[0].reshape((28, 28))
plt.imshow(img.asnumpy(), cmap='gray')
plt.show()
Now, this doesn't get nearly as good as the repo's example from above, although the setup is fairly similar.
Thus I was wondering if you could take a look and help me figure out why:
- the colors are inverted
- the results are subpar
I have been fiddling around with this, trying a lot of different things to improve the results (I will list them in a second), but for the MNIST dataset this really shouldn't be needed.
Things I have tried (and I have also tried a host of combinations):
- increasing the generator network
- increasing the discriminator network
- using soft labeling
- using noisy labeling
- batch norm after every layer in the generator
- batch norm of the data
- normalizing all values between -1 and 1
- leaky relus in the generator
- drop out layers in the generator
- increased learning rate of discriminator compared to generator
- decreased learning rate of the discriminator compared to the generator
Please let me know if you have any ideas.