I am fairly new to machine learning and am currently trying to build a simple feedforward neural network on severely imbalanced data. The data consists of 64 variables (all normalized) and one binary target variable (0 or 1) that the network is supposed to predict. There are 43405 rows in total, 2091 of class 1 and 41314 of class 0. A high prediction accuracy on class 1 is the goal.
I am not really sure what is happening, but it seems that the network is not learning on the class 1 data (which is the important one). Before I implemented sample weights (I was not able to get class weights to work; see the sketch below), the output showed an overall accuracy that was always >93%. After implementing sample weights (their magnitude does not significantly change anything), the overall accuracy dropped to around 40%, but the class 1 accuracy stayed extremely low. Changing the learning rate does not change anything, and neither does changing the architecture.
I have no idea what I am doing wrong or how I could solve this problem. As this is fairly important for my thesis, I would be extremely happy for any kind of help! If anything is missing from my description or I was not clear enough, please ask.
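For reference, here is roughly what the class-weight variant would look like (I could not get this approach to work, which is why I generate per-sample weights in the full code below; the snippet is only a sketch and reuses `model`, `x_train`, `y_train`, `x_test` and `y_test` from that code):

```
# Sketch only, not my working code: fit() also accepts a class_weight
# dict mapping class index -> weight, which should behave like the
# per-sample weights I build further down.
model.fit(
    x_train, y_train,
    validation_data=(x_test, y_test),
    epochs=80,
    class_weight={0: 1, 1: 50},
)
```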
My code looks as follows:

```
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
from keras import backend as K
from sklearn.model_selection import train_test_split
from keras.utils import np_utils
from sklearn.utils import class_weight
from sklearn.preprocessing import LabelEncoder
import matplotlib
import keras
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
df = pd.read_pickle('nfinaldf.pkl')
df = df.drop(columns = ['index'])
x = df.drop(columns = ['status'])
y = to_categorical(df.status)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.33, random_state = 50)
classes = np.unique(y_train)
model = Sequential()
n_cols = x_train.shape[1]
model.add(Dense(60, activation='relu', input_shape=(n_cols,)))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(2, activation = 'softmax'))
def generate_sample_weights(training_data, class_weight_dictionary):
    # Look up the weight of each one-hot encoded row's class.
    sample_weights = [class_weight_dictionary[np.where(one_hot_row == 1)[0][0]]
                      for one_hot_row in training_data]
    return np.asarray(sample_weights)
class_weights_dict = { 0 : 1, 1 : 50}
optimizer = keras.optimizers.Adam(lr=0.0001)
def sensitivity(y_true, y_pred):
    # True positives over all actual positives (recall), with epsilon to avoid division by zero.
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())

def specificity(y_true, y_pred):
    # True negatives over all actual negatives.
    true_negatives = K.sum(K.round(K.clip((1 - y_true) * (1 - y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1 - y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
INTERESTING_CLASS_ID = 1
def single_class_accuracy(y_true, y_pred):
    # Accuracy restricted to samples that were predicted as the interesting class.
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    accuracy_mask = K.cast(K.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc
model.compile(optimizer = optimizer, loss = 'binary_crossentropy', metrics=[sensitivity, specificity,'accuracy', single_class_accuracy])
early_stopping_monitor = EarlyStopping(patience = 30)
history = model.fit(x_train, y_train, validation_data = (x_test, y_test), epochs = 80, callbacks=[early_stopping_monitor], sample_weight = generate_sample_weights(y_train, class_weights_dict))
score = model.evaluate(x_test, y_test, verbose=0)
print(score)  # loss followed by the compiled metrics
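# Added for clarity (not part of my original run): a confusion matrix on the
# test set is how I check the class 1 performance mentioned above.
from sklearn.metrics import confusion_matrix
y_pred_classes = np.argmax(model.predict(x_test), axis=1)
y_true_classes = np.argmax(y_test, axis=1)
print(confusion_matrix(y_true_classes, y_pred_classes))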
fig, axs = plt.subplots(4)
fig.suptitle('Vertically stacked subplots')
axs[0].plot(history.history['sensitivity'], label='sensitivity (training data)')
axs[0].plot(history.history['specificity'], label='specificity (training data)')
axs[0].legend(loc="upper left")
axs[1].plot(history.history['val_sensitivity'], label='val_sensitivity (validation data)')
axs[1].plot(history.history['val_specificity'], label='val_specificity (validation data)')
axs[1].legend(loc="upper left")
axs[2].plot(history.history['loss'], label='loss (training data)')
axs[2].legend(loc="upper left")
axs[3].plot(history.history['single_class_accuracy'], label='single_class_accuracy (training data)')
axs[3].legend(loc="upper left")
plt.show()
```

Overview of data:

```
index Attr1 Attr2 Attr3 Attr4 Attr5 Attr6 Attr7 \
0 0.0 0.200550 0.37951 0.39641 2.0472 32.3510 0.38825 0.249760
1 1.0 0.209120 0.49988 0.47225 1.9447 14.7860 0.00000 0.258340
2 2.0 0.248660 0.69592 0.26713 1.5548 -1.1523 0.00000 0.309060
3 3.0 0.081483 0.30734 0.45879 2.4928 51.9520 0.14988 0.092704
4 4.0 0.187320 0.61323 0.22960 1.4063 -7.3128 0.18732 0.187320
Attr8 Attr9 Attr10 Attr11 Attr12 Attr13 Attr14 Attr15 \
0 1.33050 1.1389 0.50494 0.249760 0.65980 0.166600 0.249760 497.42
1 0.99601 1.6996 0.49788 0.261140 0.51680 0.158350 0.258340 677.96
2 0.43695 1.3090 0.30408 0.312580 0.64184 0.244350 0.309060 794.16
3 1.86610 1.0571 0.57353 0.092704 0.30163 0.094257 0.092704 917.01
4 0.63070 1.1559 0.38677 0.187320 0.33147 0.121820 0.187320 1133.20
Attr16 Attr17 Attr18 Attr19 Attr20 Attr22 Attr23 Attr24 \
0 0.73378 2.6349 0.249760 0.149420 43.370 0.21402 0.119980 0.477060
1 0.53838 2.0005 0.258340 0.152000 87.981 0.24806 0.123040 0.292903
2 0.45961 1.4369 0.309060 0.236100 73.133 0.30260 0.189960 0.300091
3 0.39803 3.2537 0.092704 0.071428 79.788 0.11550 0.062782 0.171930
4 0.32211 1.6307 0.187320 0.115530 57.045 0.19832 0.115530 0.187320
Attr25 Attr26 Attr27 Attr28 Attr29 Attr30 Attr31 Attr32 \
0 0.50494 0.60411 1.45820 1.7615 5.9443 0.11788 0.149420 94.14
1 0.39542 0.43992 88.44400 16.9460 3.6884 0.26969 0.152000 122.17
2 0.28932 0.37282 86.01100 1.0627 4.3749 0.41929 0.238150 176.93
3 0.57353 0.36152 0.94076 1.9618 4.6511 0.14343 0.071428 91.37
4 0.38677 0.32211 1.41380 1.1184 4.1424 0.27884 0.115530 147.04
Attr33 Attr34 Attr35 Attr36 Attr38 Attr39 Attr40 Attr41 \
0 3.8772 0.56393 0.21402 1.7410 0.50591 0.128040 0.662950 0.051402
1 2.9876 2.98760 0.20616 1.6996 0.49788 0.121300 0.086422 0.064371
2 2.0630 1.42740 0.31565 1.3090 0.51537 0.241140 0.322020 0.074020
3 3.9948 0.37581 0.11550 1.3562 0.57353 0.088995 0.401390 0.069622
4 2.4823 0.32340 0.19832 1.6278 0.43489 0.122310 0.293040 0.096680
Attr42 Attr43 Attr44 Attr45 Attr46 Attr47 Attr48 Attr49 \
0 0.128040 114.42 71.050 1.00970 1.52250 49.394 0.185300 0.110850
1 0.145950 199.49 111.510 0.51045 1.12520 100.130 0.237270 0.139610
2 0.231170 165.51 92.381 0.94807 1.01010 96.372 0.291810 0.222930
3 0.088995 180.77 100.980 0.28720 1.56960 84.344 0.085874 0.066165
4 0.122310 141.62 84.574 0.73919 0.95787 65.936 0.188110 0.116010
Attr50 Attr51 Attr52 Attr53 Attr54 Attr55 Attr56 Attr57 \
0 2.0420 0.37854 0.25792 2.2437 2.2480 348690.0 0.121960 0.39718
1 1.9447 0.49988 0.33472 17.8660 17.8660 2304.6 0.121300 0.42002
2 1.0758 0.48152 0.48474 1.2098 2.0504 6332.7 0.241140 0.81774
3 2.4928 0.30734 0.25033 2.4524 2.4524 20545.0 0.054015 0.14207
4 1.2959 0.56511 0.40285 1.8839 2.1184 3186.6 0.134850 0.48431
Attr58 Attr59 Attr60 Attr61 Attr62 Attr63 Attr64 status year
0 0.87804 0.001924 8.4160 5.1372 82.658 4.4158 7.4277 0.0 1.0
1 0.85300 0.000000 4.1486 3.2732 107.350 3.4000 60.9870 0.0 1.0
2 0.76599 0.694840 4.9909 3.9510 134.270 2.7185 5.2078 0.0 1.0
3 0.94598 0.000000 4.5746 3.6147 86.435 4.2228 5.5497 0.0 1.0
4 0.86515 0.124440 6.3985 4.3158 127.210 2.8692 7.8980 0.0 1.0
```