0

I have been developing an LSTM Model for alanyzing EEG data and guessing the motion being performed within a certain time. I am using raw EEG data from https://archive.physionet.org/pn4/eegmmidb/ and I believe I am experiencing overfitting, or some sort of error in my code. It is attatched below, with the compile results.

    #Initializations of all of the online stuff we are using for the model
    #Imports
    import numpy as np
    import mne.io
    import os
    import mne
    import torch as th
    import torch.autograd as autograd
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import braindecode
    #Froms
    from numpy.random import RandomState
    from tqdm.auto import tqdm
    from multiprocessing import cpu_count
    from sklearn.model_selection import train_test_split
    from sklearn import preprocessing
    from mne import Epochs, pick_types, events_from_annotations
    from mne.datasets import eegbci
    from mne.io import concatenate_raws, read_raw_edf
    from torch.utils.data import TensorDataset, DataLoader, Dataset
    from torch.optim.lr_scheduler import _LRScheduler
    from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing
    from braindecode.datautil.signal_target import SignalAndTarget
    from braindecode.torch_ext.util import np_to_var, var_to_np
    from braindecode.torch_ext.optimizers import AdamW
    #Next is class initializations and variable initializaions
    #LSTMModel class initialization
    class LSTMModel(nn.Module):
        #n_features is the number of channels and num_classes is the number of actions that can be taken, seq_dim = 128
        def __init__(self, n_features, n_classes, n_hidden=60, n_layers=2, dropout=0.25):
            super().__init__()
            self.n_hidden = n_hidden
            self.n_layers = n_layers
            self.rnn1 = nn.LSTM(
                input_size = n_features, 
                hidden_size = n_hidden, 
                num_layers = n_layers,
                batch_first=True,
                dropout=dropout
            )
            self.rnn2 = nn.LSTM(
                input_size = n_hidden,
                hidden_size = n_hidden,
                num_layers = n_layers - 1,
                batch_first=True,
                dropout=dropout
            )
            self.dropout = nn.Dropout(dropout)
            self.classifier = nn.Linear(n_hidden, n_classes)
        def forward(self, x):
            h0, c0 = self.init_hidden(x)
            #Temp print statements for checking accuracy
            #print(x)
            #print(h0)
            #print(c0)
            out, (hn, cn) = self.rnn1(x, (h0, c0))
            out, (hn, cn) = self.rnn2(out)
            out = self.dropout(out[:, -1, :])
            out = self.classifier(out)
            return out
        def init_hidden(self, x):
            x = x.type(th.float64)
            h0 = th.zeros(self.n_layers, x.shape[0], self.n_hidden)
            c0 = th.zeros(self.n_layers, x.shape[0], self.n_hidden)
            return h0, c0
    tmin = -1
    tmax = 4
    N_EPOCHS = 250
    BATCH_SIZE = 128
    num_classes = 3
    hidden_size = 256
    num_layers = 2
    bs = 128
    lr = 0.0001
    physionet_paths = []
    timeTo = 16
    badList = []
    input_size = 801
    in_chans = 10
    #Data getting and cropping
    #Yes these first few lines are redundant but I want these here incase I ever want to do more than just S001's files
    for file in os.listdir("1.0.0"):
        if (file == "S001" or file == "S002" or file == "S003" or file == "S004"):
        #if (file[0] == "S" and file[1] == "0" and file != "S092" and file != "S088"):
            print(file)
            for file1 in os.listdir("1.0.0\\" + file):
                print(file1)
                check = file1[6]
                #if (check == "3" or check == "4" or check == "7" or check == "8" or check == "11" or check == "12") and file1[-1] == "f":
                #if file1 == "S001R03.edf" or file1 == "S001R04.edf" or file1 == "S001R07.edf" or file1 == "S001R08.edf" or file1 == "S001R11.edf" or file1 == "S001R12.edf":
                if file1[-1] == "f":
                    chToExclude = 64 - in_chans
                    chList = ['Af3.', 'Afz.', 'Af4.', 'Af8.', 'F7..', 'F5..', 'F3..', 'F1..', 'Fz..', 'F2..', 'F4..', 'F6..', 'F8..', 'Ft7.', 'Ft8.', 'T7..', 'T8..', 'T9..', 'T10.', 'Tp7.', 'Tp8.', 'P7..', 'P5..', 'P3..', 'P1..', 'Pz..', 'P2..', 'P4..', 'P6..', 'P8..', 'Po7.', 'Po3.', 'Poz.', 'Po4.', 'Po8.', 'O1..', 'Oz..', 'O2..', 'Iz..', 'Fpz.', 'Fp2.', 'Af7.', 'Cp2.', 'Cp4.', 'Cp6.', 'Fp1.', 'Cp1.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..', 'Cz..', 'C2..', 'C4..', 'C6..', 'Cp5.', 'Cp3.', 'Fcz.', 'Fc3.', 'Fc1.', 'Fc5.', 'Cpz']
                    physionet_paths.append("1.0.0\\" + file + "\\" + file1)
    #Now on to combining all of the data into a raw file, and cropping it
    for i in range(0, chToExclude):
        badList.append(chList[i])
    event_id = dict(still=1, left=2, right=3) #Add/remove left=2, right=3)
    parts = [read_raw_edf(path, preload=True,stim_channel='auto', exclude=badList)
            for path in physionet_paths]
    for part in parts:
        print(part)
        print(part.info['sfreq'])
        #part.crop(0, timeTo, include_tmax=True)
    print(parts)
    raw = concatenate_raws(parts)
    print(raw)
    picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                    exclude='bads')
    #Now we initialize the events as well as the Epochs we are going to be using for our data.
    events, _ = events_from_annotations(raw, event_id=dict(T0=1, T1=2, T2=3))
    epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)
    epochs_train = epochs.copy()
    #Split data into numbers and labels respectively, then split them into training and validating sets randomly
    origX = (epochs_train.get_data())
    np.swapaxes(origX, 1, 2)
    y = (epochs_train.events[:, 2] - 2)
    train_X, valid_X, train_y, valid_y = train_test_split(origX, y, test_size=0.2, random_state=32)
    in_time_len = 801
    #Normalizing the data
    """X_mean = train_X.mean(axis=0)
    valX_mean = valid_X.mean(axis=0)
    X_std = train_X.std(axis=0)
    valX_std = valid_X.std(axis=0)
    train_X = (train_X - X_mean) / X_std
    valid_X = (valid_X - valX_mean) / valX_std
    """
    normalizer = preprocessing.Normalizer(norm='l2')

    # Reshape the 3D data to 2D
    train_X_2d = train_X.reshape(train_X.shape[0], -1)
    valid_X_2d = valid_X.reshape(valid_X.shape[0], -1)

    # Normalize the 2D data
    train_X_normalized = normalizer.transform(train_X_2d)
    valid_X_normalized = normalizer.transform(valid_X_2d)

    # Reshape the normalized data back to 3D
    train_X_normalized = train_X_normalized.reshape(train_X.shape)
    valid_X_normalized = valid_X_normalized.reshape(valid_X.shape)
    train_X = train_X_normalized
    valid_X = valid_X_normalized

    #Creating our working data sets
    train_ds = SignalAndTarget(train_X, train_y)
    valid_ds = SignalAndTarget(valid_X, valid_y)
    #Next we initialize the model.
    model = LSTMModel(input_size, num_classes, hidden_size, num_layers)
    #Create a test input so we can get what our output size will be
    test_input = np_to_var(np.ones((origX.shape[0], in_chans, in_time_len), dtype=np.float32))
    #Initialize the optimizer, iterator, loss function, and some other helpful variables we will use later.
    opt = AdamW(model.parameters(), lr=lr, weight_decay=0)
    out = model(test_input)
    n_preds_per_input = out.cpu().data.numpy().shape[1]
    iterator = CropsFromTrialsIterator(batch_size=BATCH_SIZE,input_time_length=in_time_len,
                                    n_preds_per_input=n_preds_per_input)
    n_updates_per_epoch = len([None for b in iterator.get_batches(train_ds, shuffle=True)])
    criterion = nn.CrossEntropyLoss()
    #Begin training loop
    best_accuracy = 0
    for epoch in range(1, N_EPOCHS + 1):
        model = model.float()
        for (x_batch, y_batch) in iterator.get_batches(train_ds, shuffle=True):
            model.train()
            x_batch = np_to_var(x_batch, dtype=float)
            x_batch = x_batch.squeeze()
            y_batch = np.array(y_batch) + 1.0
            opt.zero_grad()
            tOut = model(x_batch.float())
            #out = th.mean(out, dim=1, keepdim=False) I think keep this out
            y_batch = th.from_numpy(y_batch)
            y_batch = y_batch.type(th.LongTensor)
            tLoss = criterion(tOut, y_batch)
            tLoss.backward()
            opt.step()
            tPreds = th.argmax(tOut, dim=1)
            correct = sum(1 for a, b in zip(tPreds, y_batch) if a == b)
        accuracy = correct / len(y_batch)
        print("Training Epoch: " + str(epoch) + ". Loss: " + str(tLoss.item()) + ". Acc.: " + str(accuracy * 100))
        correct, total = 0, 0
        model.eval()
        #Validation loop
        for (val_X, val_Y) in iterator.get_batches(valid_ds, shuffle=True):
            val_X = np_to_var(val_X, dtype=np.double)
            val_Y = val_Y + 1
            val_Y = np_to_var(val_Y)
            val_X = val_X.squeeze()
            val_Y = val_Y.type(th.LongTensor)
            ypred = model(val_X.float())
            loss = criterion(ypred, val_Y)
            preds = th.argmax(ypred, dim=1)
            correct1 = sum(1 for a, b in zip(preds, val_Y) if a == b)
        acc = correct1 / len(val_Y)
        print("Validation Epoch: " + str(epoch) + ". Loss: " + str(loss.item()) + ". Acc.: " + str(acc * 100))
        #If the model is the new best one save it
        if acc > best_accuracy:
            best_accuracy = acc
            th.save(model.state_dict(), 'best.pth')
            print("New best model")

Training Epoch: 1. Loss: 1.0954054594039917. Acc.: 44.71544715447154 Validation Epoch: 1. Loss: 1.0879806280136108. Acc.: 53.956834532374096 New best model Training Epoch: 2. Loss: 1.079361081123352. Acc.: 48.78048780487805 Validation Epoch: 2. Loss: 1.083898663520813. Acc.: 44.60431654676259 Training Epoch: 3. Loss: 1.07877779006958. Acc.: 42.27642276422765 Validation Epoch: 3. Loss: 1.0489959716796875. Acc.: 53.23741007194245 Training Epoch: 4. Loss: 1.0769835710525513. Acc.: 43.08943089430895 Validation Epoch: 4. Loss: 1.019529104232788. Acc.: 53.956834532374096 Training Epoch: 5. Loss: 1.0317955017089844. Acc.: 50.40650406504065 Validation Epoch: 5. Loss: 1.0308650732040405. Acc.: 50.35971223021583 Training Epoch: 6. Loss: 1.0240803956985474. Acc.: 51.21951219512195 Validation Epoch: 6. Loss: 1.022445559501648. Acc.: 51.798561151079134 Training Epoch: 7. Loss: 1.033755898475647. Acc.: 48.78048780487805 Validation Epoch: 7. Loss: 1.0355108976364136. Acc.: 48.92086330935252 Training Epoch: 8. Loss: 1.0010309219360352. Acc.: 52.84552845528455 Validation Epoch: 8. Loss: 1.0161426067352295. Acc.: 51.07913669064749 Training Epoch: 9. Loss: 0.9809502363204956. Acc.: 54.47154471544715 Validation Epoch: 9. Loss: 0.992912232875824. Acc.: 52.51798561151079 Training Epoch: 10. Loss: 1.0562916994094849. Acc.: 45.52845528455284 Validation Epoch: 10. Loss: 0.951147198677063. Acc.: 52.51798561151079 Training Epoch: 11. Loss: 1.0033303499221802. Acc.: 50.40650406504065 Validation Epoch: 11. Loss: 1.0053869485855103. Acc.: 46.043165467625904 Training Epoch: 12. Loss: 1.0156131982803345. Acc.: 50.40650406504065 Validation Epoch: 12. Loss: 0.9308319091796875. Acc.: 51.798561151079134 Training Epoch: 13. Loss: 1.0211182832717896. Acc.: 46.34146341463415 Validation Epoch: 13. Loss: 0.9038681983947754. Acc.: 59.71223021582733 New best model Training Epoch: 14. Loss: 1.020182728767395. Acc.: 45.52845528455284 Validation Epoch: 14. Loss: 0.9128692746162415. Acc.: 57.55395683453237 Training Epoch: 15. Loss: 0.9860080480575562. Acc.: 54.47154471544715 Validation Epoch: 15. Loss: 0.9508476257324219. Acc.: 53.23741007194245 Training Epoch: 16. Loss: 0.980677604675293. Acc.: 51.21951219512195 Validation Epoch: 16. Loss: 0.9205264449119568. Acc.: 61.87050359712231 New best model Training Epoch: 17. Loss: 0.9726919531822205. Acc.: 51.21951219512195 Validation Epoch: 17. Loss: 0.9313606023788452. Acc.: 58.27338129496403 Training Epoch: 18. Loss: 0.932588517665863. Acc.: 56.91056910569105 Validation Epoch: 18. Loss: 0.9397252798080444. Acc.: 51.798561151079134 Training Epoch: 19. Loss: 0.9076926112174988. Acc.: 55.28455284552846 Validation Epoch: 19. Loss: 0.9691327810287476. Acc.: 53.956834532374096 Training Epoch: 20. Loss: 0.9545367360115051. Acc.: 56.91056910569105 Validation Epoch: 20. Loss: 0.8751280903816223. Acc.: 61.87050359712231 Training Epoch: 21. Loss: 0.8569953441619873. Acc.: 64.22764227642277 Validation Epoch: 21. Loss: 0.9191533923149109. Acc.: 58.992805755395686 Training Epoch: 22. Loss: 0.9185190796852112. Acc.: 56.91056910569105 Validation Epoch: 22. Loss: 0.9195245504379272. Acc.: 60.431654676258994 Training Epoch: 23. Loss: 0.9557067155838013. Acc.: 54.47154471544715 Validation Epoch: 23. Loss: 0.9277586340904236. Acc.: 54.67625899280576 Training Epoch: 24. Loss: 0.9821110367774963. Acc.: 49.59349593495935 Validation Epoch: 24. Loss: 0.9196327924728394. Acc.: 58.992805755395686 Training Epoch: 25. Loss: 0.8572725653648376. Acc.: 62.601626016260155 Validation Epoch: 25. Loss: 0.9899344444274902. Acc.: 55.39568345323741 Training Epoch: 26. Loss: 0.9966101050376892. Acc.: 50.40650406504065 Validation Epoch: 26. Loss: 0.9025679230690002. Acc.: 61.15107913669065 Training Epoch: 27. Loss: 0.9764003157615662. Acc.: 53.65853658536586 Validation Epoch: 27. Loss: 0.8876821994781494. Acc.: 58.27338129496403 Training Epoch: 28. Loss: 0.917256236076355. Acc.: 54.47154471544715 Validation Epoch: 28. Loss: 0.9570827484130859. Acc.: 56.11510791366906 Training Epoch: 29. Loss: 0.8812904953956604. Acc.: 59.34959349593496 Validation Epoch: 29. Loss: 0.9196577072143555. Acc.: 59.71223021582733 Training Epoch: 30. Loss: 0.8500334024429321. Acc.: 66.66666666666666 Validation Epoch: 30. Loss: 0.8992879390716553. Acc.: 59.71223021582733 Training Epoch: 31. Loss: 0.8446879982948303. Acc.: 63.41463414634146 Validation Epoch: 31. Loss: 0.9848458170890808. Acc.: 56.83453237410072 Training Epoch: 32. Loss: 0.8742657899856567. Acc.: 56.09756097560976 Validation Epoch: 32. Loss: 0.912958562374115. Acc.: 61.87050359712231 Training Epoch: 33. Loss: 0.8341238498687744. Acc.: 62.601626016260155 Validation Epoch: 33. Loss: 1.0351901054382324. Acc.: 51.798561151079134 Training Epoch: 34. Loss: 0.9610933065414429. Acc.: 52.03252032520326 Validation Epoch: 34. Loss: 0.9500865340232849. Acc.: 56.83453237410072 Training Epoch: 35. Loss: 0.798149049282074. Acc.: 59.34959349593496 Validation Epoch: 35. Loss: 0.9144001603126526. Acc.: 59.71223021582733 Training Epoch: 36. Loss: 0.8537905812263489. Acc.: 60.16260162601627 Validation Epoch: 36. Loss: 0.952705979347229. Acc.: 56.11510791366906 Training Epoch: 37. Loss: 0.853546679019928. Acc.: 59.34959349593496 Validation Epoch: 37. Loss: 0.8856449127197266. Acc.: 56.11510791366906 Training Epoch: 38. Loss: 0.7937348484992981. Acc.: 63.41463414634146 Validation Epoch: 38. Loss: 0.9971412420272827. Acc.: 54.67625899280576 Training Epoch: 39. Loss: 0.83415687084198. Acc.: 56.91056910569105 Validation Epoch: 39. Loss: 0.975570797920227. Acc.: 54.67625899280576 ....... Validation Epoch: 190. Loss: 2.3696634769439697. Acc.: 51.798561151079134 Training Epoch: 191. Loss: 0.24480554461479187. Acc.: 91.05691056910568 Validation Epoch: 191. Loss: 2.2647438049316406. Acc.: 51.798561151079134 Training Epoch: 192. Loss: 0.17438578605651855. Acc.: 94.3089430894309 Validation Epoch: 192. Loss: 2.4518003463745117. Acc.: 51.798561151079134 Training Epoch: 193. Loss: 0.18941359221935272. Acc.: 93.4959349593496 Here are the results, as you can see the training seems to work but the validation does not, and the loss starts increasing once it gets up to like 30 epochs.

Please let me know what I am doing wrong and if this is actually overfitting or not. Thanks!

  • An easy way to detect overfitting is by plotting 2 loss curves over the epochs, one for training loss and other for validation loss. If the training loss is decreasing consistently, but the validation loss is not decreasing as consistently, then you can conclude that your model is overfitting. It is difficult to assess that accurately by just looking at raw numbers. – Charudatta Manwatkar Jul 05 '23 at 19:09

0 Answers0