
I'm trying to build a simple autoencoder for MNIST, where the middle layer is just 10 neurons. My hope is that it will learn to classify the 10 digits, and I assume that would lead to the lowest error in the end (w.r.t. reproducing the original image).

I have the following code, which I've already played around with a fair amount. If I run it for up to 100 epochs, the loss doesn't really go below 1.0, and if I evaluate it, it's obviously not working. What am I missing?

Training:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image

num_epochs = 100
batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
trainset = tv.datasets.MNIST(root='./data',  train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder,self).__init__()
        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(1, 4, kernel_size=5),
            nn.Dropout2d(p=0.2),
            # 4 x 24 x 24
            nn.ReLU(True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.Dropout2d(p=0.2),
            # 8 x 20 x 20 = 3200
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(3200, 10),
            nn.ReLU(True),
            # 10
            nn.Softmax(),
            # 10
            )
        self.decoder = nn.Sequential(
            # 10
            nn.Linear(10, 400),
            nn.ReLU(True),
            # 400
            nn.Unflatten(1, (1, 20, 20)),
            # 20 x 20
            nn.Dropout2d(p=0.2),
            nn.ConvTranspose2d(1, 10, kernel_size=5),
            # 24 x 24
            nn.ReLU(True),
            nn.Dropout2d(p=0.2),
            nn.ConvTranspose2d(10, 1, kernel_size=5),
            # 28 x 28
            nn.ReLU(True),
            nn.Sigmoid(),
            )
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Autoencoder().cpu()
distance = nn.MSELoss()
#optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = Variable(img).cpu()
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

The training loss already indicates that the thing is not working, but let's also print the confusion matrix (which in this case should not necessarily be the identity matrix, since the neurons can be ordered arbitrarily, but should be row/column-reorderable into an approximation of the identity if this worked):

import numpy as np

confusion_matrix = np.zeros((10, 10))

batch_size = 20*1000

testset = tv.datasets.MNIST(root='./data',  train=False, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)

for data in dataloader:
    imgs, labels = data
    imgs = Variable(imgs).cpu()
    encs = model.encoder(imgs).detach().numpy()
    for i in range(len(encs)):
        predicted = np.argmax(encs[i])
        actual = labels[i]
        confusion_matrix[actual][predicted] += 1
print(confusion_matrix)
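
For what it's worth, the "reorderable" check can be made mechanical; assuming scipy is available, the Hungarian algorithm finds the best neuron-to-digit permutation:

from scipy.optimize import linear_sum_assignment

# Find the column permutation that puts the most mass on the diagonal,
# i.e. the best-case matching of encoder neurons to digit labels.
rows, cols = linear_sum_assignment(confusion_matrix, maximize=True)
reordered = confusion_matrix[:, cols]
print(reordered)
print('best-case accuracy:', reordered.trace() / confusion_matrix.sum())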

3 Answers


Autoencoders are generally not used as classifiers. They learn how to encode a given image into a short vector and to reconstruct the same image from that encoded vector. It is a way of compressing an image into a short vector:

Autoencoder

Since you want to train the autoencoder with classification capabilities, we need to make some changes to the model. First of all, there will be two different losses:

  1. MSE loss: the current autoencoder reconstruction loss. This forces the network to output an image as close as possible to the given image, working from the compressed representation.
  2. Classification loss: classic cross-entropy should do the trick. It takes the compressed representation (C-dimensional) and the target labels to calculate the negative log-likelihood loss, which forces the encoder to output a compressed representation that aligns well with the target class.

I've made a couple of changes to your code to get the combined model working. First, let's see the code:

 import torch
 import torchvision as tv
 import torchvision.transforms as transforms
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 from torchvision.utils import save_image

 num_epochs = 10
 batch_size = 64
 transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))
 ])     
 
 trainset = tv.datasets.MNIST(root='./data',  train=True, download=True, transform=transform)
 testset  = tv.datasets.MNIST(root='./data',  train=False, download=True, transform=transform)
 dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
 testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=4)
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 class Autoencoderv3(nn.Module):
     def __init__(self):
         super(Autoencoderv3,self).__init__()
         self.encoder = nn.Sequential(
             nn.Conv2d(1, 4, kernel_size=5),
             nn.Dropout2d(p=0.1),
             nn.ReLU(True),
             nn.Conv2d(4, 8, kernel_size=5),
             nn.Dropout2d(p=0.1),
             nn.ReLU(True),
             nn.Flatten(),
             nn.Linear(3200, 10)
             )
         self.softmax = nn.Softmax(dim=1)
         self.decoder = nn.Sequential(
             nn.Linear(10, 400),
             nn.ReLU(True),
             nn.Unflatten(1, (1, 20, 20)),
             nn.Dropout2d(p=0.1),
             nn.ConvTranspose2d(1, 10, kernel_size=5),
             nn.ReLU(True),
             nn.Dropout2d(p=0.1),
             nn.ConvTranspose2d(10, 1, kernel_size=5)
             )
         
     def forward(self, x):
         out_en = self.encoder(x)
         out = self.softmax(out_en)
         out = self.decoder(out)
         return out, out_en
 
 model = Autoencoderv3().to(device)
 distance   = nn.MSELoss()
 class_loss = nn.CrossEntropyLoss()
 
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
 
 mse_multp = 0.5
 cls_multp = 0.5
 
 model.train()
 
 for epoch in range(num_epochs):
     total_mseloss = 0.0
     total_clsloss = 0.0
     for ind, data in enumerate(dataloader):
         img, labels = data[0].to(device), data[1].to(device) 
         output, output_en = model(img)
         loss_mse = distance(output, img)
         loss_cls = class_loss(output_en, labels)
         loss = (mse_multp * loss_mse) + (cls_multp * loss_cls)  # Combine two losses together
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         # Track this epoch's loss
         total_mseloss += loss_mse.item()
         total_clsloss += loss_cls.item()
 
     # Check accuracy on test set after each epoch:
     model.eval()   # Turn off dropout in evaluation mode
     acc = 0.0
     total_samples = 0
     for data in testloader:
         # We only care about the 10 dimensional encoder output for classification
         img, labels = data[0].to(device), data[1].to(device) 
         _, output_en = model(img)   
         # output_en contains 10 values for each input, apply softmax to calculate class probabilities
         prob = nn.functional.softmax(output_en, dim = 1)
         pred = torch.max(prob, dim=1)[1].detach().cpu().numpy() # Max prob assigned to class 
         acc += (pred == labels.cpu().numpy()).sum()
         total_samples += labels.shape[0]
     model.train()   # Enables dropout back again
     print('epoch [{}/{}], loss_mse: {:.4f}  loss_cls: {:.4f}  Acc on test: {:.4f}'.format(epoch+1, num_epochs, total_mseloss / len(dataloader), total_clsloss / len(dataloader), acc / total_samples))
   

This code should now train the model both as a classifier and as a generative autoencoder. In general though, this type of approach can be a bit tricky to get training. Here, the MNIST data is simple enough for the two complementary losses to train together. In more complex cases, like Generative Adversarial Networks (GANs), people alternate training phases, freeze one model, etc. to get the whole model trained. This autoencoder trains easily on MNIST without those kinds of tricks:

 epoch [1/10], loss_mse: 0.8928  loss_cls: 0.4627  Acc on test: 0.9463
 epoch [2/10], loss_mse: 0.8287  loss_cls: 0.2105  Acc on test: 0.9639
 epoch [3/10], loss_mse: 0.7803  loss_cls: 0.1574  Acc on test: 0.9737
 epoch [4/10], loss_mse: 0.7513  loss_cls: 0.1290  Acc on test: 0.9764
 epoch [5/10], loss_mse: 0.7298  loss_cls: 0.1117  Acc on test: 0.9762
 epoch [6/10], loss_mse: 0.7110  loss_cls: 0.1017  Acc on test: 0.9801
 epoch [7/10], loss_mse: 0.6962  loss_cls: 0.0920  Acc on test: 0.9794
 epoch [8/10], loss_mse: 0.6824  loss_cls: 0.0859  Acc on test: 0.9806
 epoch [9/10], loss_mse: 0.6733  loss_cls: 0.0797  Acc on test: 0.9814
 epoch [10/10], loss_mse: 0.6671  loss_cls: 0.0764  Acc on test: 0.9813

As you can see, both the MSE loss and the classification loss are decreasing, and the accuracy on the test set is increasing. In the code, the MSE loss and the classification loss are added together, which means the gradients calculated from each loss fight against each other to pull the network in their respective directions. I've added loss multipliers to control the contribution of each loss: if MSE has a higher multiplier, the network gets more gradient from the MSE loss, meaning it learns reconstruction better; if the classification loss has a higher multiplier, the network gets better classification accuracy. You can play with those multipliers to see how the end result changes, but MNIST is a very easy dataset, so the differences might be hard to see. Currently, it doesn't do too badly at reconstructing inputs:

 import numpy as np
 import matplotlib.pyplot as plt
 
 model.eval()
 img, labels = list(dataloader)[0]
 img = img.to(device)
 output, output_en = model(img)
 inp = img[0:10, 0, :, :].squeeze().detach().cpu()
 out = output[0:10, 0, :, :].squeeze().detach().cpu()
 
 # Just some trick to concatenate first ten images next to each other
 inp = inp.permute(1,0,2).reshape(28, -1).numpy()
 out = out.permute(1,0,2).reshape(28, -1).numpy()
 combined = np.vstack([inp, out])
 
 plt.imshow(combined)
 plt.show()

Reconstruction

I am sure that with more training and fine-tuning of the loss multipliers, you can get better results.

Lastly, the decoder receives the softmax of the encoder output. This means the decoder tries to create the output image from 0-1 class probabilities. So if the softmax probability vector is 0.98 at position 0 and close to zero elsewhere, the decoder should output an image that looks like a 0. Here I feed the network one-hot inputs to create reconstructions of the digits 0 to 9:

 test_arr = np.zeros([10, 10], dtype = np.float32)
 ind = np.arange(0, 10)
 test_arr[ind, ind] = 1.0
 
 model.eval()
 img = torch.from_numpy(test_arr).to(device)
 out = model.decoder(img)
 out = out[0:10, 0, :, :].squeeze().detach().cpu()
 out = out.permute(1,0,2).reshape(28, -1).numpy()
 plt.imshow(out)
 plt.show()

0 to 9 reconstruction

I've also made a few small changes to the code, such as printing the epoch's average loss, which don't really change the training logic; you can see those changes in the code, and let me know if anything looks weird.

yutasrobot
  • I'm doing this just for fun; I'm trying to see if I can get the encoder to do exactly digit classification as a "side effect" in the middle of the autoencoder. It seems to me that with the encoded dimension set to 10, that's the most optimal encoding, or at least one of the optimal-ish ones. – Marton Trencseni Mar 18 '21 at 05:03
  • see my answer above: for mathematical reasons regarding the softmax, this wouldn't work; try Gumbel-Softmax with varying temperature and it might just work. Otherwise, there is a lot of work in the space of discrete variational autoencoders. – maggu Mar 18 '21 at 11:38
  • It can actually work, you can train the network with two losses added together: classification and reconstruction loss. See my edited answer. – yutasrobot Mar 19 '21 at 01:10

I was able to bring your code to a version where it at least converges. In summary, I think there are multiple problems with it: the normalization (why those values?), some unnecessary ReLUs, a too-high learning rate, MSE loss instead of cross-entropy, and, mainly, I don't think the softmax in the bottleneck layer works that way, for vanishing-gradient reasons; see here:

https://www.quora.com/Does-anyone-ever-use-a-softmax-layer-mid-neural-network-rather-than-at-the-end

Maybe one could fix this using the Gumbel softmax: https://arxiv.org/abs/1611.01144
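
If you want to try that, PyTorch ships torch.nn.functional.gumbel_softmax, so the bottleneck could look roughly like the sketch below. The temperature tau is something you would anneal over training; this is an untested sketch, not a drop-in fix:

import torch
import torch.nn.functional as F

class GumbelBottleneck(torch.nn.Module):
    def __init__(self, tau=1.0):
        super().__init__()
        self.tau = tau  # anneal towards a small value during training

    def forward(self, logits):
        # hard=True returns a straight-through one-hot sample: the decoder
        # sees a discrete code, but gradients still flow to the encoder.
        return F.gumbel_softmax(logits, tau=self.tau, hard=True)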

Moreover, there are papers already achieving this, but as a Variational Autoencoder rather than a vanilla autoencoder, see here: https://arxiv.org/abs/1609.02200.

For now, you can use this modification, which at least converges, and then modify it step by step to see what breaks it.

As for the classification, the standard way would be to use the trained encoder to generate features from images and then fit a normal classifier (SVM or similar) on top of them; see the sketch after the code below.

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.autograd import Variable
from tqdm import tqdm

batch_size = 16

transform = transforms.Compose([
    transforms.ToTensor(),
])
trainset = MNIST(root='./data/',  train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder,self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(2, 4, kernel_size=5),
            )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 10, kernel_size=5),
            nn.ReLU(),
            nn.ConvTranspose2d(10, 1, kernel_size=5),
            nn.Sigmoid(),
            )
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Autoencoder().cpu()
distance = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,weight_decay=1e-5)

num_epochs = 20

outputs = []
for epoch in tqdm(range(num_epochs)):
    for data in dataloader:
        img, _ = data
        img = Variable(img).cpu()
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        outputs.append(output.detach())  # detach so stored outputs don't hold on to the autograd graph
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))



import matplotlib.pyplot as plt
# plotting epoch outputs
for k in range(0, 20):
    plt.figure(figsize=(9, 2))
    imgs = outputs[k].detach().numpy()
    for i, item in enumerate(imgs):
        plt.imshow(item[0])
        plt.title(str(i))
        plt.show()
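
And here is the rough sketch of the encoder-features-plus-classifier route mentioned above. It assumes the trained model and dataloader from this answer and uses scikit-learn's LinearSVC; treat it as an untested illustration, not a tuned pipeline:

import numpy as np
from sklearn.svm import LinearSVC

def extract_features(model, loader):
    # Run the frozen encoder over a loader and collect one flat
    # feature vector (4 x 20 x 20 = 1600 values) per image.
    feats, labels = [], []
    model.eval()
    with torch.no_grad():
        for img, lab in loader:
            f = model.encoder(img)
            feats.append(f.flatten(1).numpy())
            labels.append(lab.numpy())
    return np.concatenate(feats), np.concatenate(labels)

X_train, y_train = extract_features(model, dataloader)
clf = LinearSVC().fit(X_train, y_train)
print('train accuracy:', clf.score(X_train, y_train))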
maggu
    A note regarding the normalization values, these are actually the mean and std of the training data: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457 – Shaido Mar 18 '21 at 03:07
  • Not sure why you needed to "get it to compile", I copy/pasted this from my ipython notebook. – Marton Trencseni Mar 18 '21 at 05:09
  • sorry, edited: i meant of course "converge" instead – maggu Mar 18 '21 at 11:27
  • to further improve convergence with a sigmoidal output, you can use cross-entropy loss – maggu Mar 18 '21 at 11:35

I played around with your code (from above and Github) and found the following:

  1. Sigmoid: when your code loads the MNIST dataset, you apply a Transform to normalize the data, but your Autoencoder model uses nn.Sigmoid() as its final layer, which forces the output to be in the range [0, 1], while the normalized data is more like [-.4242, 2.8215] (see the quick check after this list). Commenting out the sigmoid layer helps greatly reduce the loss during training.

  2. Softmax: I understand why you include the nn.Softmax() layer: to try to force the 10 learned features to be used sparsely for reconstructing each image. It does help raise the test accuracy in some cases. After trying a few ideas (like annealing a softmax temperature), it feels like a single float to reconstruct each class of digit is just insufficient.

  3. Clustering: another way to use the features to predict 1 of 10 positions for each image is to cluster the feature representations (over some set of training/dev samples). I tried this and found it helped raise the test accuracy.

  4. CNN: I found a different CNN AE model from here that works a little bit better in the experiments I ran.

  5. Optimizer: I found that the Adam optimizer with LR=.001 works better than the ad-hoc values I tried with SGD, Adam, and Adadelta.

  6. Finally, I found that wrapping the img with Variable() is not needed, so I removed that.
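
As a quick check of the range claim in item 1 (the transform maps each pixel x to (x - mean) / std):

(0.0 - 0.1307) / 0.3081   # ≈ -0.4242, a black pixel after normalization
(1.0 - 0.1307) / 0.3081   # ≈  2.8215, a white pixel after normalization

so a final sigmoid, bounded to [0, 1], can never reproduce most of that range.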

Below is the final code I ended up with. After 25 epochs of training:

  • loss is around .03
  • test accuracy is around 60%

Training to 100 epochs doesn't seem to improve things. Here is a sample of the before/after digits at epoch 8 where the loss=.0436:

Before/after digits sample

import os
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

# work around sklearn warning
os.environ["OMP_NUM_THREADS"] = "4"
import sklearn.cluster as cluster

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class Autoencoder(nn.Module):
    def __init__(self, d_hidden=10, use_softmax=False):
        super(Autoencoder,self).__init__()

        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(1, 4, kernel_size=5),
            # 4 x 24 x 24
            nn.ReLU(True),
            nn.Conv2d(4, 8, kernel_size=5),
            nn.ReLU(True),
            # 8 x 20 x 20 = 3200
            nn.Flatten(),
            nn.Linear(3200, d_hidden),
            # d_hidden
            #nn.Softmax(dim=-1),
            )

        self.use_softmax = use_softmax

        self.decoder = nn.Sequential(
            # d_hidden
            nn.Linear(d_hidden, 400),
            # 400
            nn.ReLU(True),
            nn.Linear(400, 4000),
            # 4000
            nn.ReLU(True),
            nn.Unflatten(1, (10, 20, 20)),
            # 10 x 20 x 20
            nn.ConvTranspose2d(10, 10, kernel_size=5),
            # 24 x 24
            nn.ConvTranspose2d(10, 1, kernel_size=5),
            # 28 x 28
            #nn.Sigmoid(),
            )

    def forward(self, x, temperature):
        features = self.encoder(x)

        if self.use_softmax:
            features = torch.softmax(features/temperature, dim=-1)

        output = self.decoder(features)
        return output

    def get_features(self, x):
        features = self.encoder(x)
        return features

class NewAutoencoder(nn.Module):
    def __init__(self, d_hidden=64, use_softmax=False):
        super(NewAutoencoder, self).__init__()

        self.encoder = nn.Sequential( # like the Composition layer you built
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, d_hidden, 7)
        )

        if use_softmax:
            self.encoder.add_module("softmax", nn.Softmax(dim=-1))

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(d_hidden, 32, 7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
            #nn.Sigmoid()
        )

    def forward(self, x, temperature):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def get_features(self, x):
        output = self.encoder(x).squeeze(-1).squeeze(-1)
        return output

def show_samples(model, dataloader, device, epoch, temperature):
    model.eval()
    img, labels = list(dataloader)[0]
    img = img.to(device)
    output = model(img, temperature)
    inp = img[0:10, 0, :, :].squeeze().detach().cpu()
    out = output[0:10, 0, :, :].squeeze().detach().cpu()
    
    # Just some trick to concatenate first ten images next to each other
    inp = inp.permute(1,0,2).reshape(28, -1).numpy()
    out = out.permute(1,0,2).reshape(28, -1).numpy()
    combined = np.vstack([inp, out])
    
    plt.title("epoch: {}".format(epoch))

    plt.imshow(combined)
    plt.draw()
    plt.pause(0.1)

def calc_position_to_label_mapping(model, alignloader, device, quick, cluster_map):
    model.eval()
    position_counts_by_label = defaultdict(dict)   # key=label, value=dict(key=position, value-count)
    labels_by_position = {key:None for key in range(10)}

    # collect counts of each position, by label
    for images, labels in alignloader:
        images = images.to(device)
        output = model.get_features(images)

        if cluster_map:
            feature_data = output.cpu().detach().numpy()
            preds = cluster_map.predict(feature_data)
        else:
            preds = torch.argmax(output, dim=-1)

        for lab, pred in zip(labels, preds):
            label = int(lab)
            position = int(pred)
            pc = position_counts_by_label[label]

            if position in pc:
                pc[position] += 1
            else:
                pc[position] = 1

    '''
    Note: at this point, a particular position could be the best predictor
    of more than one label.  Since each position can only predict a single
    label, we want to choose the overall assignment of position -> label
    that will maximize our accuracy.  The algorithm below estimates this
    best assignment:

        - normalize all counts by label
        - for the remaining labels/positions:
        - find the position->label assignment with the greatest difference between the top-scoring #1
          and #2 assignments within a label, across all labels
            - record that position->label assignment and remove position and assignment from pool.
        - repeat above step until all labels have been assigned a (unique) position.
    '''

    # normalize counts by label
    for label, pc in position_counts_by_label.items():
        total = sum(pc.values())
        for key in pc:
            pc[key] /= total

    # repeat until done
    remaining_positions = {key:1 for key in range(10)}

    while position_counts_by_label:
        # find strongest position/label assignment
        best = None

        for label, pc in position_counts_by_label.items():
            if len(pc) == 0:
                # no remaining positions predicted this label
                position = next(iter(remaining_positions))
                best = (label, position, 1)
                break

            if len(pc) == 1:
                # automatic winner
                best = (label, next(iter(pc)), 1)
                break

            pcx = dict(pc)
            key1 = max(pcx, key=pcx.get)
            del pcx[key1]
            key2 = max(pcx, key=pcx.get)
            
            diff = pc[key1] - pc[key2]
            #diff = pc[key1]

            if best is None or diff > best[2]:
                best = (label, key1, diff)
            
        # record chosen position/label
        label, position, score = best
        labels_by_position[position] = label

        # remove position/label from pool
        del position_counts_by_label[label]
        del remaining_positions[position]
        
        for pc in position_counts_by_label.values():
            if position in pc:
                del pc[position]

    return labels_by_position

def cluster_features(model, alignloader, device, quick):
    all_features = None

    for images, unused_labels in alignloader:

        images = images.to(device)
        features = model.get_features(images)

        if all_features is None:
            all_features = features
        else:
            all_features = torch.vstack( [all_features, features] )

    feature_data = all_features.cpu().detach().numpy()

    kmeans = cluster.KMeans(n_clusters=10)  
    kmeans.fit(feature_data)

    return kmeans

def eval_test(model, testloader, device, quick, labels_by_position, cluster_map):

    correct = 0
    samples = 0

    for images, labels in testloader:
        images = images.to(device)
        output = model.get_features(images)

        if cluster_map:
            feature_data = output.cpu().detach().numpy()
            preds = cluster_map.predict(feature_data)
        else:
            preds = torch.argmax(output, dim=-1)

        for lab, pred in zip(labels, preds):
            label = int(lab)
            position = int(pred)
            pred_label = labels_by_position[position]

            if label == pred_label:
                correct += 1
            samples += 1

    print("labels_by_position: {}".format(list(labels_by_position.values())))

    test_acc = correct/samples
    name = "Estimated" if quick else "Total"
    print("{} test acc: {:.4f} (samples: {:,})".format(name, test_acc, len(testloader.dataset)))

def evaluate(model, testloader, device, quick, use_clustering):
    
    if use_clustering:
        cluster_map = cluster_features(model, testloader, device, True)
        labels_by_position = calc_position_to_label_mapping(model, testloader, device, True, cluster_map)
        eval_test(model, testloader, device, True, labels_by_position, cluster_map)

    else:
        labels_by_position = calc_position_to_label_mapping(model, testloader, device, True, None)
        eval_test(model, testloader, device, True, labels_by_position, None)

def train():
    num_epochs = 25
    batch_size = 64
    test_samples = 1000

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    trainset = tv.datasets.MNIST(root='d:/.data/mnist',  train=True, download=True, transform=transform)
    testset = tv.datasets.MNIST(root='d:/.data/mnist',  train=False, download=True, transform=transform)

    indexes = list(range(test_samples))
    quick_testset = torch.utils.data.Subset(testset, indexes)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    quick_testloader = torch.utils.data.DataLoader(quick_testset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda")
    use_orig_ae = False
    use_softmax = False
    use_clustering = True
    d_hidden = 32 if use_clustering else 10
    plot_before_after_digits = True

    if use_orig_ae:
        model = Autoencoder(d_hidden=d_hidden, use_softmax=use_softmax).to(device)
    else:
        model = NewAutoencoder(d_hidden=d_hidden, use_softmax=use_softmax).to(device)

    temperature = 1

    distance = nn.MSELoss()
    #distance = nn.L1Loss()

    #optimizer = torch.optim.SGD(model.parameters(), lr=.05)  # , lr=.01, momentum=0.5)
    #optimizer = torch.optim.Adadelta(model.parameters(), lr=1)   
    optimizer = torch.optim.Adam(model.parameters(), lr=.001)   

    for epoch in range(num_epochs):
        model.train()

        for data in trainloader:
            optimizer.zero_grad()

            img, _ = data
            #img = Variable(img).to(device)
            img = img.to(device)

            output = model(img, temperature)

            loss = distance(output, img)
            loss.backward()
            optimizer.step()

        print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
        evaluate(model, quick_testloader, device, True, use_clustering)
        print()

        if plot_before_after_digits:
            show_samples(model, trainloader, device, epoch+1, temperature)

        temperature = .9*temperature

    # after training, do final (and full) eval
    evaluate(model, testloader, device, False, use_clustering)

    _ = input("hit RETURN to dismiss plot and end program")

if __name__ == "__main__":
    train()