I have been developing an LSTM model for analyzing EEG data and classifying which movement is being performed in a given time window. I am using raw EEG data from https://archive.physionet.org/pn4/eegmmidb/ and I believe I am experiencing overfitting, or some sort of error in my code. The code is attached below, along with the training output.
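For reference, the same recordings can also be fetched programmatically with MNE's bundled eegbci helper; the sketch below is purely illustrative (the subject and run numbers are placeholders, not necessarily what my code further down uses):

#Illustrative sketch only: download a few EEG Motor Movement/Imagery runs with MNE's eegbci helper
#(subject/run numbers here are placeholders)
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
edf_paths = eegbci.load_data(1, [3, 4, 7, 8, 11, 12])  #subject 1, six motor runs
raw_example = concatenate_raws([read_raw_edf(p, preload=True) for p in edf_paths])
print(raw_example.info)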
#Initialization of all of the external libraries we are using for the model
#Imports
import numpy as np
import mne.io
import os
import mne
import torch as th
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import braindecode
#Froms
from numpy.random import RandomState
from tqdm.auto import tqdm
from multiprocessing import cpu_count
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from mne import Epochs, pick_types, events_from_annotations
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.optim.lr_scheduler import _LRScheduler
from braindecode.torch_ext.schedulers import ScheduledOptimizer, CosineAnnealing
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.torch_ext.util import np_to_var, var_to_np
from braindecode.torch_ext.optimizers import AdamW
from braindecode.datautil.iterators import CropsFromTrialsIterator
#Next are the class and variable initializations
#LSTMModel class initialization
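#The model below feeds the input through a two-layer LSTM (rnn1), then a single-layer LSTM (rnn2),
#applies dropout to the output at the final time step, and maps it to n_classes scores with a linear classifier.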
class LSTMModel(nn.Module):
    #n_features is the number of channels and n_classes is the number of actions that can be taken, seq_dim = 128
    def __init__(self, n_features, n_classes, n_hidden=60, n_layers=2, dropout=0.25):
        super().__init__()
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.rnn1 = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout
        )
        self.rnn2 = nn.LSTM(
            input_size=n_hidden,
            hidden_size=n_hidden,
            num_layers=n_layers - 1,
            batch_first=True,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(n_hidden, n_classes)
    def forward(self, x):
        h0, c0 = self.init_hidden(x)
        #Temp print statements for checking accuracy
        #print(x)
        #print(h0)
        #print(c0)
        out, (hn, cn) = self.rnn1(x, (h0, c0))
        out, (hn, cn) = self.rnn2(out)
        out = self.dropout(out[:, -1, :])
        out = self.classifier(out)
        return out
    def init_hidden(self, x):
        x = x.type(th.float64)
        h0 = th.zeros(self.n_layers, x.shape[0], self.n_hidden)
        c0 = th.zeros(self.n_layers, x.shape[0], self.n_hidden)
        return h0, c0
tmin = -1
tmax = 4
N_EPOCHS = 250
BATCH_SIZE = 128
num_classes = 3
hidden_size = 256
num_layers = 2
bs = 128
lr = 0.0001
physionet_paths = []
timeTo = 16
badList = []
input_size = 801
in_chans = 10
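#Note: most runs in this dataset are sampled at 160 Hz, so a -1 s to 4 s epoch contains 5 * 160 + 1 = 801 samples,
#which is where input_size and in_time_len = 801 come from.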
#Data getting and cropping
#Yes these first few lines are redundant, but I want them here in case I ever want to do more than just S001's files
for file in os.listdir("1.0.0"):
    if (file == "S001" or file == "S002" or file == "S003" or file == "S004"):
        #if (file[0] == "S" and file[1] == "0" and file != "S092" and file != "S088"):
        print(file)
        for file1 in os.listdir("1.0.0\\" + file):
            print(file1)
            check = file1[6]
            #if (check == "3" or check == "4" or check == "7" or check == "8" or check == "11" or check == "12") and file1[-1] == "f":
            #if file1 == "S001R03.edf" or file1 == "S001R04.edf" or file1 == "S001R07.edf" or file1 == "S001R08.edf" or file1 == "S001R11.edf" or file1 == "S001R12.edf":
            if file1[-1] == "f":
                chToExclude = 64 - in_chans
                chList = ['Af3.', 'Afz.', 'Af4.', 'Af8.', 'F7..', 'F5..', 'F3..', 'F1..', 'Fz..', 'F2..', 'F4..', 'F6..', 'F8..', 'Ft7.', 'Ft8.', 'T7..', 'T8..', 'T9..', 'T10.', 'Tp7.', 'Tp8.', 'P7..', 'P5..', 'P3..', 'P1..', 'Pz..', 'P2..', 'P4..', 'P6..', 'P8..', 'Po7.', 'Po3.', 'Poz.', 'Po4.', 'Po8.', 'O1..', 'Oz..', 'O2..', 'Iz..', 'Fpz.', 'Fp2.', 'Af7.', 'Cp2.', 'Cp4.', 'Cp6.', 'Fp1.', 'Cp1.', 'Fc2.', 'Fc4.', 'Fc6.', 'C5..', 'C3..', 'C1..', 'Cz..', 'C2..', 'C4..', 'C6..', 'Cp5.', 'Cp3.', 'Fcz.', 'Fc3.', 'Fc1.', 'Fc5.', 'Cpz']
                physionet_paths.append("1.0.0\\" + file + "\\" + file1)
#Now on to combining all of the data into a raw file, and cropping it
for i in range(0, chToExclude):
    badList.append(chList[i])
event_id = dict(still=1, left=2, right=3) #Add/remove left=2, right=3)
parts = [read_raw_edf(path, preload=True, stim_channel='auto', exclude=badList)
         for path in physionet_paths]
for part in parts:
    print(part)
    print(part.info['sfreq'])
    #part.crop(0, timeTo, include_tmax=True)
print(parts)
raw = concatenate_raws(parts)
print(raw)
picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')
#Now we initialize the events as well as the Epochs we are going to be using for our data.
events, _ = events_from_annotations(raw, event_id=dict(T0=1, T1=2, T2=3))
epochs = Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, baseline=None, preload=True)
epochs_train = epochs.copy()
#Split the data into features and labels, then split those into training and validation sets randomly
origX = (epochs_train.get_data())
np.swapaxes(origX, 1, 2)
y = (epochs_train.events[:, 2] - 2)
train_X, valid_X, train_y, valid_y = train_test_split(origX, y, test_size=0.2, random_state=32)
in_time_len = 801
#Normalizing the data
"""X_mean = train_X.mean(axis=0)
valX_mean = valid_X.mean(axis=0)
X_std = train_X.std(axis=0)
valX_std = valid_X.std(axis=0)
train_X = (train_X - X_mean) / X_std
valid_X = (valid_X - valX_mean) / valX_std
"""
normalizer = preprocessing.Normalizer(norm='l2')
# Reshape the 3D data to 2D
train_X_2d = train_X.reshape(train_X.shape[0], -1)
valid_X_2d = valid_X.reshape(valid_X.shape[0], -1)
# Normalize the 2D data
train_X_normalized = normalizer.transform(train_X_2d)
valid_X_normalized = normalizer.transform(valid_X_2d)
# Reshape the normalized data back to 3D
train_X_normalized = train_X_normalized.reshape(train_X.shape)
valid_X_normalized = valid_X_normalized.reshape(valid_X.shape)
train_X = train_X_normalized
valid_X = valid_X_normalized
#Creating our working data sets
train_ds = SignalAndTarget(train_X, train_y)
valid_ds = SignalAndTarget(valid_X, valid_y)
#Next we initialize the model.
model = LSTMModel(input_size, num_classes, hidden_size, num_layers)
#Create a test input so we can get what our output size will be
test_input = np_to_var(np.ones((origX.shape[0], in_chans, in_time_len), dtype=np.float32))
#Initialize the optimizer, iterator, loss function, and some other helpful variables we will use later.
opt = AdamW(model.parameters(), lr=lr, weight_decay=0)
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[1]
iterator = CropsFromTrialsIterator(batch_size=BATCH_SIZE, input_time_length=in_time_len,
                                   n_preds_per_input=n_preds_per_input)
n_updates_per_epoch = len([None for b in iterator.get_batches(train_ds, shuffle=True)])
criterion = nn.CrossEntropyLoss()
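#Note: nn.CrossEntropyLoss expects raw, unnormalized logits of shape (batch, n_classes)
#and integer class targets in the range [0, n_classes - 1].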
#Begin training loop
best_accuracy = 0
for epoch in range(1, N_EPOCHS + 1):
    model = model.float()
    for (x_batch, y_batch) in iterator.get_batches(train_ds, shuffle=True):
        model.train()
        x_batch = np_to_var(x_batch, dtype=float)
        x_batch = x_batch.squeeze()
        y_batch = np.array(y_batch) + 1.0
        opt.zero_grad()
        tOut = model(x_batch.float())
        #out = th.mean(out, dim=1, keepdim=False) I think keep this out
        y_batch = th.from_numpy(y_batch)
        y_batch = y_batch.type(th.LongTensor)
        tLoss = criterion(tOut, y_batch)
        tLoss.backward()
        opt.step()
        tPreds = th.argmax(tOut, dim=1)
        correct = sum(1 for a, b in zip(tPreds, y_batch) if a == b)
        accuracy = correct / len(y_batch)
    print("Training Epoch: " + str(epoch) + ". Loss: " + str(tLoss.item()) + ". Acc.: " + str(accuracy * 100))
    correct, total = 0, 0
    model.eval()
    #Validation loop
    for (val_X, val_Y) in iterator.get_batches(valid_ds, shuffle=True):
        val_X = np_to_var(val_X, dtype=np.double)
        val_Y = val_Y + 1
        val_Y = np_to_var(val_Y)
        val_X = val_X.squeeze()
        val_Y = val_Y.type(th.LongTensor)
        ypred = model(val_X.float())
        loss = criterion(ypred, val_Y)
        preds = th.argmax(ypred, dim=1)
        correct1 = sum(1 for a, b in zip(preds, val_Y) if a == b)
        acc = correct1 / len(val_Y)
    print("Validation Epoch: " + str(epoch) + ". Loss: " + str(loss.item()) + ". Acc.: " + str(acc * 100))
    #If the model is the new best one save it
    if acc > best_accuracy:
        best_accuracy = acc
        th.save(model.state_dict(), 'best.pth')
        print("New best model")
Training Epoch: 1. Loss: 1.0954054594039917. Acc.: 44.71544715447154
Validation Epoch: 1. Loss: 1.0879806280136108. Acc.: 53.956834532374096
New best model
Training Epoch: 2. Loss: 1.079361081123352. Acc.: 48.78048780487805
Validation Epoch: 2. Loss: 1.083898663520813. Acc.: 44.60431654676259
Training Epoch: 3. Loss: 1.07877779006958. Acc.: 42.27642276422765
Validation Epoch: 3. Loss: 1.0489959716796875. Acc.: 53.23741007194245
Training Epoch: 4. Loss: 1.0769835710525513. Acc.: 43.08943089430895
Validation Epoch: 4. Loss: 1.019529104232788. Acc.: 53.956834532374096
Training Epoch: 5. Loss: 1.0317955017089844. Acc.: 50.40650406504065
Validation Epoch: 5. Loss: 1.0308650732040405. Acc.: 50.35971223021583
Training Epoch: 6. Loss: 1.0240803956985474. Acc.: 51.21951219512195
Validation Epoch: 6. Loss: 1.022445559501648. Acc.: 51.798561151079134
Training Epoch: 7. Loss: 1.033755898475647. Acc.: 48.78048780487805
Validation Epoch: 7. Loss: 1.0355108976364136. Acc.: 48.92086330935252
Training Epoch: 8. Loss: 1.0010309219360352. Acc.: 52.84552845528455
Validation Epoch: 8. Loss: 1.0161426067352295. Acc.: 51.07913669064749
Training Epoch: 9. Loss: 0.9809502363204956. Acc.: 54.47154471544715
Validation Epoch: 9. Loss: 0.992912232875824. Acc.: 52.51798561151079
Training Epoch: 10. Loss: 1.0562916994094849. Acc.: 45.52845528455284
Validation Epoch: 10. Loss: 0.951147198677063. Acc.: 52.51798561151079
Training Epoch: 11. Loss: 1.0033303499221802. Acc.: 50.40650406504065
Validation Epoch: 11. Loss: 1.0053869485855103. Acc.: 46.043165467625904
Training Epoch: 12. Loss: 1.0156131982803345. Acc.: 50.40650406504065
Validation Epoch: 12. Loss: 0.9308319091796875. Acc.: 51.798561151079134
Training Epoch: 13. Loss: 1.0211182832717896. Acc.: 46.34146341463415
Validation Epoch: 13. Loss: 0.9038681983947754. Acc.: 59.71223021582733
New best model
Training Epoch: 14. Loss: 1.020182728767395. Acc.: 45.52845528455284
Validation Epoch: 14. Loss: 0.9128692746162415. Acc.: 57.55395683453237
Training Epoch: 15. Loss: 0.9860080480575562. Acc.: 54.47154471544715
Validation Epoch: 15. Loss: 0.9508476257324219. Acc.: 53.23741007194245
Training Epoch: 16. Loss: 0.980677604675293. Acc.: 51.21951219512195
Validation Epoch: 16. Loss: 0.9205264449119568. Acc.: 61.87050359712231
New best model
Training Epoch: 17. Loss: 0.9726919531822205. Acc.: 51.21951219512195
Validation Epoch: 17. Loss: 0.9313606023788452. Acc.: 58.27338129496403
Training Epoch: 18. Loss: 0.932588517665863. Acc.: 56.91056910569105
Validation Epoch: 18. Loss: 0.9397252798080444. Acc.: 51.798561151079134
Training Epoch: 19. Loss: 0.9076926112174988. Acc.: 55.28455284552846
Validation Epoch: 19. Loss: 0.9691327810287476. Acc.: 53.956834532374096
Training Epoch: 20. Loss: 0.9545367360115051. Acc.: 56.91056910569105
Validation Epoch: 20. Loss: 0.8751280903816223. Acc.: 61.87050359712231
Training Epoch: 21. Loss: 0.8569953441619873. Acc.: 64.22764227642277
Validation Epoch: 21. Loss: 0.9191533923149109. Acc.: 58.992805755395686
Training Epoch: 22. Loss: 0.9185190796852112. Acc.: 56.91056910569105
Validation Epoch: 22. Loss: 0.9195245504379272. Acc.: 60.431654676258994
Training Epoch: 23. Loss: 0.9557067155838013. Acc.: 54.47154471544715
Validation Epoch: 23. Loss: 0.9277586340904236. Acc.: 54.67625899280576
Training Epoch: 24. Loss: 0.9821110367774963. Acc.: 49.59349593495935
Validation Epoch: 24. Loss: 0.9196327924728394. Acc.: 58.992805755395686
Training Epoch: 25. Loss: 0.8572725653648376. Acc.: 62.601626016260155
Validation Epoch: 25. Loss: 0.9899344444274902. Acc.: 55.39568345323741
Training Epoch: 26. Loss: 0.9966101050376892. Acc.: 50.40650406504065
Validation Epoch: 26. Loss: 0.9025679230690002. Acc.: 61.15107913669065
Training Epoch: 27. Loss: 0.9764003157615662. Acc.: 53.65853658536586
Validation Epoch: 27. Loss: 0.8876821994781494. Acc.: 58.27338129496403
Training Epoch: 28. Loss: 0.917256236076355. Acc.: 54.47154471544715
Validation Epoch: 28. Loss: 0.9570827484130859. Acc.: 56.11510791366906
Training Epoch: 29. Loss: 0.8812904953956604. Acc.: 59.34959349593496
Validation Epoch: 29. Loss: 0.9196577072143555. Acc.: 59.71223021582733
Training Epoch: 30. Loss: 0.8500334024429321. Acc.: 66.66666666666666
Validation Epoch: 30. Loss: 0.8992879390716553. Acc.: 59.71223021582733
Training Epoch: 31. Loss: 0.8446879982948303. Acc.: 63.41463414634146
Validation Epoch: 31. Loss: 0.9848458170890808. Acc.: 56.83453237410072
Training Epoch: 32. Loss: 0.8742657899856567. Acc.: 56.09756097560976
Validation Epoch: 32. Loss: 0.912958562374115. Acc.: 61.87050359712231
Training Epoch: 33. Loss: 0.8341238498687744. Acc.: 62.601626016260155
Validation Epoch: 33. Loss: 1.0351901054382324. Acc.: 51.798561151079134
Training Epoch: 34. Loss: 0.9610933065414429. Acc.: 52.03252032520326
Validation Epoch: 34. Loss: 0.9500865340232849. Acc.: 56.83453237410072
Training Epoch: 35. Loss: 0.798149049282074. Acc.: 59.34959349593496
Validation Epoch: 35. Loss: 0.9144001603126526. Acc.: 59.71223021582733
Training Epoch: 36. Loss: 0.8537905812263489. Acc.: 60.16260162601627
Validation Epoch: 36. Loss: 0.952705979347229. Acc.: 56.11510791366906
Training Epoch: 37. Loss: 0.853546679019928. Acc.: 59.34959349593496
Validation Epoch: 37. Loss: 0.8856449127197266. Acc.: 56.11510791366906
Training Epoch: 38. Loss: 0.7937348484992981. Acc.: 63.41463414634146
Validation Epoch: 38. Loss: 0.9971412420272827. Acc.: 54.67625899280576
Training Epoch: 39. Loss: 0.83415687084198. Acc.: 56.91056910569105
Validation Epoch: 39. Loss: 0.975570797920227. Acc.: 54.67625899280576
.......
Validation Epoch: 190. Loss: 2.3696634769439697. Acc.: 51.798561151079134
Training Epoch: 191. Loss: 0.24480554461479187. Acc.: 91.05691056910568
Validation Epoch: 191. Loss: 2.2647438049316406. Acc.: 51.798561151079134
Training Epoch: 192. Loss: 0.17438578605651855. Acc.: 94.3089430894309
Validation Epoch: 192. Loss: 2.4518003463745117. Acc.: 51.798561151079134
Training Epoch: 193. Loss: 0.18941359221935272. Acc.: 93.4959349593496

Here are the results. As you can see, the training accuracy keeps improving but the validation accuracy does not, and the validation loss starts increasing after roughly 30 epochs.
Please let me know what I am doing wrong and if this is actually overfitting or not. Thanks!