I played around with your code (from above and GitHub) and found the following:
Sigmoid: when your code loads the MNIST dataset, it applies a Transform that normalizes the data, but your Autoencoder model uses nn.Sigmoid() as its final layer, which forces the output into the range [0, 1] (while the normalized data spans roughly [-0.4242, 2.8215]). Commenting out the sigmoid layer greatly reduces the loss during training.
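A quick way to see the mismatch (a standalone sketch, separate from the script below, that just applies the same Normalize values to the pixel extremes 0 and 1):

import torch
import torchvision.transforms as transforms

normalize = transforms.Normalize((0.1307,), (0.3081,))
extremes = torch.tensor([0.0, 1.0]).view(1, 2, 1)  # min/max pixel values after ToTensor()
print(normalize(extremes).flatten())  # tensor([-0.4242, 2.8215]) -- well outside Sigmoid's [0, 1] range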
Softmax: I understand why you include the nn.Softmax() layer: to try to force the 10 learned features to be used sparsely when reconstructing each image. It does help raise the test accuracy in some cases. But after trying a few ideas (like annealing a softmax temperature), it feels like a single float per digit class is just not enough to reconstruct that class.
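For reference, the temperature-annealed softmax I tried boils down to something like this sketch (with a hypothetical sparse_features helper; the full script below does the same thing inside Autoencoder.forward() and decays temperature by 0.9 each epoch):

import torch

def sparse_features(features, temperature):
    # lower temperature -> sharper (more nearly one-hot) usage of the 10 features
    return torch.softmax(features / temperature, dim=-1)

temperature = 1.0
for epoch in range(25):
    # ... encode a batch to `features`, decode sparse_features(features, temperature), backprop ...
    temperature *= 0.9  # anneal toward a harder, more discrete assignment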
Clustering: another way to map each image's features to one of 10 positions is to cluster the feature representations (over some set of training/dev samples). I tried this and found it helped raise the test accuracy.
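In isolation, that idea looks roughly like this sketch (a simplified version with a hypothetical cluster_and_map helper that names each cluster by a plain majority vote; the full script below fits KMeans the same way but uses a greedy position -> label assignment instead):

import numpy as np
import sklearn.cluster as cluster

def cluster_and_map(features, labels, n_clusters=10):
    # features: (N, d_hidden) numpy array of encoder outputs
    # labels:   (N,) numpy integer array of ground-truth digits, used only to name the clusters
    kmeans = cluster.KMeans(n_clusters=n_clusters, n_init=10).fit(features)
    cluster_ids = kmeans.labels_
    # name each cluster after the digit that occurs most often inside it
    mapping = {c: int(np.bincount(labels[cluster_ids == c], minlength=10).argmax())
               for c in range(n_clusters)}
    return kmeans, mapping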
CNN: I found a different CNN AE model from here that works a little bit better in the experiments I ran.
Optimizer: I found that the Adam optimizer with LR=.001 works better than the ad-hoc values I tried with SGD, Adam, and Adadelta.
Finally, I found that wrapping the img with Variable() is not needed (Variable has been deprecated since PyTorch 0.4; tensors handle autograd directly), so I removed it.
Below is the final code I ended up with. After 25 epochs of training:
- loss is around 0.03
- test accuracy is around 60%
Training to 100 epochs doesn't seem to improve things. Here is a sample of the before/after digits at epoch 8, where the loss = 0.0436:

import os
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
# work around an sklearn/KMeans OMP warning by capping the OpenMP thread count
os.environ["OMP_NUM_THREADS"] = "4"
import sklearn.cluster as cluster
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class Autoencoder(nn.Module):
def __init__(self, d_hidden=10, use_softmax=False):
super(Autoencoder,self).__init__()
self.encoder = nn.Sequential(
# 28 x 28
nn.Conv2d(1, 4, kernel_size=5),
# 4 x 24 x 24
nn.ReLU(True),
nn.Conv2d(4, 8, kernel_size=5),
nn.ReLU(True),
# 8 x 20 x 20 = 3200
nn.Flatten(),
nn.Linear(3200, d_hidden),
# d_hidden
#nn.Softmax(dim=-1),
)
self.use_softmax = use_softmax
self.decoder = nn.Sequential(
# d_hidden
nn.Linear(d_hidden, 400),
# 400
nn.ReLU(True),
nn.Linear(400, 4000),
# 4000
nn.ReLU(True),
nn.Unflatten(1, (10, 20, 20)),
# 10 x 20 x 20
nn.ConvTranspose2d(10, 10, kernel_size=5),
# 10 x 24 x 24
nn.ConvTranspose2d(10, 1, kernel_size=5),
# 1 x 28 x 28
#nn.Sigmoid(),
)
def forward(self, x, temperature):
features = self.encoder(x)
if self.use_softmax:
features = torch.softmax(features/temperature, dim=-1)
output = self.decoder(features)
return output
def get_features(self, x):
features = self.encoder(x)
return features
class NewAutoencoder(nn.Module):
def __init__(self, d_hidden=64, use_softmax=False):
super(NewAutoencoder, self).__init__()
self.encoder = nn.Sequential( # like the Composition layer you built
nn.Conv2d(1, 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, d_hidden, 7)
)
if use_softmax:
# softmax over the channel (feature) dimension; dim=-1 would act on the trailing size-1 spatial dim and be a no-op
self.encoder.add_module("softmax", nn.Softmax(dim=1))
self.decoder = nn.Sequential(
nn.ConvTranspose2d(d_hidden, 32, 7),
nn.ReLU(),
nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
nn.ReLU(),
nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
#nn.Sigmoid()
)
def forward(self, x, temperature):
# temperature is unused in this model; accepted only to match Autoencoder's forward() signature
x = self.encoder(x)
x = self.decoder(x)
return x
def get_features(self, x):
output = self.encoder(x).squeeze(-1).squeeze(-1)
return output
def show_samples(model, dataloader, device, epoch, temperature):
model.eval()
img, labels = next(iter(dataloader))  # grab a single batch without materializing the whole loader
img = img.to(device)
output = model(img, temperature)
inp = img[0:10, 0, :, :].squeeze().detach().cpu()
out = output[0:10, 0, :, :].squeeze().detach().cpu()
# a small trick to concatenate the first ten images side by side
inp = inp.permute(1,0,2).reshape(28, -1).numpy()
out = out.permute(1,0,2).reshape(28, -1).numpy()
combined = np.vstack([inp, out])
plt.title("epoch: {}".format(epoch))
plt.imshow(combined)
plt.draw()
plt.pause(0.1)
def calc_position_to_label_mapping(model, alignloader, device, quick, cluster_map):
model.eval()
position_counts_by_label = defaultdict(dict) # key=label, value=dict(key=position, value=count)
labels_by_position = {key:None for key in range(10)}
# collect counts of each position, by label
for images, labels in alignloader:
images = images.to(device)
output = model.get_features(images)
if cluster_map:
feature_data = output.cpu().detach().numpy()
preds = cluster_map.predict(feature_data)
else:
preds = torch.argmax(output, dim=-1)
for lab, pred in zip(labels, preds):
label = int(lab)
position = int(pred)
pc = position_counts_by_label[label]
if position in pc:
pc[position] += 1
else:
pc[position] = 1
'''
Note: at this point, a particular position could be the best predictor of more than
one label. Since each position can only predict a single label, we want to choose the
overall assignment of positions -> labels that maximizes our accuracy. The algorithm
below estimates this best assignment:
- normalize all counts by label
- among the remaining labels/positions, find the position -> label assignment with the
  greatest difference between the #1 and #2 scoring positions within a label, across all labels
- record that position -> label assignment and remove both the position and the label from the pool
- repeat until every label has been assigned a (unique) position.
'''
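# Example with hypothetical normalized counts: if label 3 is predicted by position 5 with score 0.70
# and by position 2 with 0.20, while label 8 is predicted by position 5 with 0.45 and by position 1
# with 0.40, then (position 5 -> label 3) is assigned first (margin 0.50 vs 0.05); position 5 is then
# removed from the pool and label 8 later falls back to position 1.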
# normalize counts by label
for label, pc in position_counts_by_label.items():
total = sum(pc.values())
for key in pc:
pc[key] /= total
# repeat until done
remaining_positions = {key:1 for key in range(10)}
while position_counts_by_label:
# find strongest position/label assignment
best = None
for label, pc in position_counts_by_label.items():
if len(pc) == 0:
# no remaining positions predicted this label
position = next(iter(remaining_positions))
best = (label, position, 1)
break
if len(pc) == 1:
# automatic winner
best = (label, next(iter(pc)), 1)
break
pcx = dict(pc)
key1 = max(pcx, key=pcx.get)
del pcx[key1]
key2 = max(pcx, key=pcx.get)
diff = pc[key1] - pc[key2]
#diff = pc[key1]
if best is None or diff > best[2]:
best = (label, key1, diff)
# record chosen position/label
label, position, score = best
labels_by_position[position] = label
# remove position/label from pool
del position_counts_by_label[label]
del remaining_positions[position]
for pc in position_counts_by_label.values():
if position in pc:
del pc[position]
return labels_by_position
def cluster_features(model, alignloader, device, quick):
all_features = None
for images, unused_labels in alignloader:
images = images.to(device)
features = model.get_features(images)
if all_features is None:
all_features = features
else:
all_features = torch.vstack( [all_features, features] )
feature_data = all_features.cpu().detach().numpy()
kmeans = cluster.KMeans(n_clusters=10)
kmeans.fit(feature_data)
return kmeans
def eval_test(model, testloader, device, quick, labels_by_position, cluster_map):
correct = 0
samples = 0
for images, labels in testloader:
images = images.to(device)
output = model.get_features(images)
if cluster_map:
feature_data = output.cpu().detach().numpy()
preds = cluster_map.predict(feature_data)
else:
preds = torch.argmax(output, dim=-1)
for lab, pred in zip(labels, preds):
label = int(lab)
position = int(pred)
pred_label = labels_by_position[position]
if label == pred_label:
correct += 1
samples += 1
print("labels_by_position: {}".format(list(labels_by_position.values())))
test_acc = correct/samples
name = "Estimated" if quick else "Total"
print("{} test acc: {:.4f} (samples: {:,})".format(name, test_acc, len(testloader.dataset)))
def evaluate(model, testloader, device, quick, use_clustering):
if use_clustering:
cluster_map = cluster_features(model, testloader, device, True)
labels_by_position = calc_position_to_label_mapping(model, testloader, device, True, cluster_map)
eval_test(model, testloader, device, True, labels_by_position, cluster_map)
else:
labels_by_position = calc_position_to_label_mapping(model, testloader, device, True, None)
eval_test(model, testloader, device, True, labels_by_position, None)
def train():
num_epochs = 25
batch_size = 64
test_samples = 1000
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
trainset = tv.datasets.MNIST(root='d:/.data/mnist', train=True, download=True, transform=transform)
testset = tv.datasets.MNIST(root='d:/.data/mnist', train=False, download=True, transform=transform)
indexes = list(range(test_samples))
quick_testset = torch.utils.data.Subset(testset, indexes)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
quick_testloader = torch.utils.data.DataLoader(quick_testset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_orig_ae = False
use_softmax = False
use_clustering = True
d_hidden = 32 if use_clustering else 10
plot_before_after_digits = True
if use_orig_ae:
model = Autoencoder(d_hidden=d_hidden, use_softmax=use_softmax).to(device)
else:
model = NewAutoencoder(d_hidden=d_hidden, use_softmax=use_softmax).to(device)
temperature = 1
distance = nn.MSELoss()
#distance = nn.L1Loss()
#optimizer = torch.optim.SGD(model.parameters(), lr=.05) # , lr=.01, momentum=0.5)
#optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
optimizer = torch.optim.Adam(model.parameters(), lr=.001)
for epoch in range(num_epochs):
model.train()
for data in trainloader:
optimizer.zero_grad()
img, _ = data
#img = Variable(img).to(device)
img = img.to(device)
output = model(img, temperature)
loss = distance(output, img)
loss.backward()
optimizer.step()
print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
evaluate(model, quick_testloader, device, True, use_clustering)
print()
if plot_before_after_digits:
show_samples(model, trainloader, device, epoch+1, temperature)
temperature = .9*temperature
# after training, do final (and full) eval
evaluate(model, testloader, device, False, use_clustering)
_ = input("hit RETURN to dismiss plot and end program")
if __name__ == "__main__":
train()