How do I fix a neural network that won't generalize on highly imbalanced data?
I am quite new to machine learning and am currently trying to build a simple feed-forward neural network on severely imbalanced data. The data consists of 64 input variables (all normalized) and 1 binary target variable (1s and 0s) that the NN should predict. There are 43,405 rows in total, 2,091 of class 1 and 41,314 of class 0, i.e. only about 4.8% positives. The goal is high prediction accuracy for class 1.
I am honestly not sure what is going on, but it looks to me as if the NN is not learning the class-1 data at all (which is the data that matters). Before I implemented sample weights (I could not get class weights to work), the reported overall accuracy was consistently above 93%; after implementing sample weights it dropped to around 40%, but the class-1 accuracy stayed extremely low either way. Changing the learning rate changes nothing, and neither does changing the architecture.
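For completeness, this is roughly the class-weight route I was aiming for before falling back to per-sample weights. It is only a sketch of what I assume it should look like, not working code from my script: the script below imports class_weight and computes classes but never uses them, and compute_class_weight plus the class_weight argument of model.fit are my assumptions here.

# Sketch only: compute balanced class weights from the integer labels
# and hand them to model.fit directly, instead of per-sample weights.
from sklearn.utils import class_weight
import numpy as np
import pandas as pd

df = pd.read_pickle('nfinaldf.pkl')
labels = df.status.astype(int).values  # integer labels, before one-hot encoding

# 'balanced' weighting is n_samples / (n_classes * class_count);
# for this data that works out to roughly {0: 0.53, 1: 10.4}
weights = class_weight.compute_class_weight('balanced',
                                            classes=np.unique(labels),
                                            y=labels)
class_weights_dict = dict(enumerate(weights))

# then, instead of hand-built sample weights:
# model.fit(x_train, y_train, ..., class_weight=class_weights_dict)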
I don't know what I am doing wrong or how to fix it. Since this is quite important for my thesis, I would really appreciate any help! If anything is missing from my description or unclear, please ask!
My code is as follows:
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
from keras import backend as K
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn.preprocessing import LabelEncoder
import matplotlib
import keras
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
df = pd.read_pickle('nfinaldf.pkl')
df = df.drop(columns = ['index'])
x = df.drop(columns = ['status'])
y = to_categorical(df.status)
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size = 0.33,random_state = 50)
classes = np.unique(y_train)
model = Sequential()
n_cols = x_train.shape[1]
model.add(Dense(60,activation='relu',input_shape=(n_cols,)))
model.add(Dense(200,activation='relu'))
model.add(Dense(200,activation='relu'))
model.add(Dense(2,activation = 'softmax'))
def generate_sample_weights(training_data,class_weight_dictionary):
    # look up the weight for each one-hot encoded label row
    sample_weights = [class_weight_dictionary[np.where(one_hot_row == 1)[0][0]]
                      for one_hot_row in training_data]
    return np.asarray(sample_weights)
class_weights_dict = { 0 : 1,1 : 50}
optimizer = keras.optimizers.Adam(lr=0.0001)
def sensitivity(y_true,y_pred):
    # true-positive rate: TP / (TP + FN)
    true_positives = K.sum(K.round(K.clip(y_true * y_pred,0,1)))
    possible_positives = K.sum(K.round(K.clip(y_true,0,1)))
    return true_positives / (possible_positives + K.epsilon())
def specificity(y_true,y_pred):
    # true-negative rate: TN / (TN + FP)
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred),0,1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true,0,1)))
    return true_negatives / (possible_negatives + K.epsilon())
INTERESTING_CLASS_ID = 1
def single_class_accuracy(y_true,y_pred):
    # accuracy restricted to samples predicted as INTERESTING_CLASS_ID
    class_id_true = K.argmax(y_true,axis=-1)
    class_id_preds = K.argmax(y_pred,axis=-1)
    accuracy_mask = K.cast(K.equal(class_id_preds,INTERESTING_CLASS_ID),'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true,class_id_preds),'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask),1)
    return class_acc
model.compile(optimizer = optimizer,loss = 'binary_crossentropy',metrics=[sensitivity,specificity,'accuracy',single_class_accuracy])
early_stopping_monitor = EarlyStopping(patience = 30)
history = model.fit(x_train,y_train,
                    validation_data = (x_test,y_test),
                    epochs = 80,
                    callbacks=[early_stopping_monitor],
                    sample_weight = generate_sample_weights(y_train,class_weights_dict))
score = model.evaluate(x_test,y_test,verbose = 0)
fig,axs = plt.subplots(4)
fig.suptitle('Training history')
axs[0].plot(history.history['sensitivity'],label='sensitivity (training data)')
axs[0].plot(history.history['specificity'],label='specificity (training data)')
axs[0].legend(loc="upper left")
axs[1].plot(history.history['val_sensitivity'],label='val_sensitivity (validation data)')
axs[1].plot(history.history['val_specificity'],label='val_specificity (validation data)')
axs[1].legend(loc="upper left")
axs[2].plot(history.history['loss'],label='loss (training data)')
axs[2].legend(loc="upper left")
axs[3].plot(history.history['single_class_accuracy'],label='single_class_accuracy (training data)')
axs[3].legend(loc="upper left")
plt.show()
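To judge how class 1 actually fares, a per-class breakdown of the test predictions seems more telling to me than the overall score. A minimal sketch (confusion_matrix and classification_report from sklearn.metrics are additions here; my script above only calls model.evaluate):

# Sketch: per-class evaluation on the test set, to see how class 1 is handled.
from sklearn.metrics import classification_report, confusion_matrix

y_pred = model.predict(x_test)             # (n, 2) softmax probabilities
y_pred_labels = np.argmax(y_pred, axis=1)  # back to integer class ids
y_true_labels = np.argmax(y_test, axis=1)  # undo the one-hot encoding

print(confusion_matrix(y_true_labels, y_pred_labels))
print(classification_report(y_true_labels, y_pred_labels, digits=3))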
Overview of data:
index Attr1 Attr2 Attr3 Attr4 Attr5 Attr6 Attr7 \
0 0.0 0.200550 0.37951 0.39641 2.0472 32.3510 0.38825 0.249760
1 1.0 0.209120 0.49988 0.47225 1.9447 14.7860 0.00000 0.258340
2 2.0 0.248660 0.69592 0.26713 1.5548 -1.1523 0.00000 0.309060
3 3.0 0.081483 0.30734 0.45879 2.4928 51.9520 0.14988 0.092704
4 4.0 0.187320 0.61323 0.22960 1.4063 -7.3128 0.18732 0.187320
Attr8 Attr9 Attr10 Attr11 Attr12 Attr13 Attr14 Attr15 \
0 1.33050 1.1389 0.50494 0.249760 0.65980 0.166600 0.249760 497.42
1 0.99601 1.6996 0.49788 0.261140 0.51680 0.158350 0.258340 677.96
2 0.43695 1.3090 0.30408 0.312580 0.64184 0.244350 0.309060 794.16
3 1.86610 1.0571 0.57353 0.092704 0.30163 0.094257 0.092704 917.01
4 0.63070 1.1559 0.38677 0.187320 0.33147 0.121820 0.187320 1133.20
Attr16 Attr17 Attr18 Attr19 Attr20 Attr22 Attr23 Attr24 \
0 0.73378 2.6349 0.249760 0.149420 43.370 0.21402 0.119980 0.477060
1 0.53838 2.0005 0.258340 0.152000 87.981 0.24806 0.123040 0.292903
2 0.45961 1.4369 0.309060 0.236100 73.133 0.30260 0.189960 0.300091
3 0.39803 3.2537 0.092704 0.071428 79.788 0.11550 0.062782 0.171930
4 0.32211 1.6307 0.187320 0.115530 57.045 0.19832 0.115530 0.187320
Attr25 Attr26 Attr27 Attr28 Attr29 Attr30 Attr31 Attr32 \
0 0.50494 0.60411 1.45820 1.7615 5.9443 0.11788 0.149420 94.14
1 0.39542 0.43992 88.44400 16.9460 3.6884 0.26969 0.152000 122.17
2 0.28932 0.37282 86.01100 1.0627 4.3749 0.41929 0.238150 176.93
3 0.57353 0.36152 0.94076 1.9618 4.6511 0.14343 0.071428 91.37
4 0.38677 0.32211 1.41380 1.1184 4.1424 0.27884 0.115530 147.04
Attr33 Attr34 Attr35 Attr36 Attr38 Attr39 Attr40 Attr41 \
0 3.8772 0.56393 0.21402 1.7410 0.50591 0.128040 0.662950 0.051402
1 2.9876 2.98760 0.20616 1.6996 0.49788 0.121300 0.086422 0.064371
2 2.0630 1.42740 0.31565 1.3090 0.51537 0.241140 0.322020 0.074020
3 3.9948 0.37581 0.11550 1.3562 0.57353 0.088995 0.401390 0.069622
4 2.4823 0.32340 0.19832 1.6278 0.43489 0.122310 0.293040 0.096680
Attr42 Attr43 Attr44 Attr45 Attr46 Attr47 Attr48 Attr49 \
0 0.128040 114.42 71.050 1.00970 1.52250 49.394 0.185300 0.110850
1 0.145950 199.49 111.510 0.51045 1.12520 100.130 0.237270 0.139610
2 0.231170 165.51 92.381 0.94807 1.01010 96.372 0.291810 0.222930
3 0.088995 180.77 100.980 0.28720 1.56960 84.344 0.085874 0.066165
4 0.122310 141.62 84.574 0.73919 0.95787 65.936 0.188110 0.116010
Attr50 Attr51 Attr52 Attr53 Attr54 Attr55 Attr56 Attr57 \
0 2.0420 0.37854 0.25792 2.2437 2.2480 348690.0 0.121960 0.39718
1 1.9447 0.49988 0.33472 17.8660 17.8660 2304.6 0.121300 0.42002
2 1.0758 0.48152 0.48474 1.2098 2.0504 6332.7 0.241140 0.81774
3 2.4928 0.30734 0.25033 2.4524 2.4524 20545.0 0.054015 0.14207
4 1.2959 0.56511 0.40285 1.8839 2.1184 3186.6 0.134850 0.48431
Attr58 Attr59 Attr60 Attr61 Attr62 Attr63 Attr64 status year
0 0.87804 0.001924 8.4160 5.1372 82.658 4.4158 7.4277 0.0 1.0
1 0.85300 0.000000 4.1486 3.2732 107.350 3.4000 60.9870 0.0 1.0
2 0.76599 0.694840 4.9909 3.9510 134.270 2.7185 5.2078 0.0 1.0
3 0.94598 0.000000 4.5746 3.6147 86.435 4.2228 5.5497 0.0 1.0
4 0.86515 0.124440 6.3985 4.3158 127.210 2.8692 7.8980 0.0 1.0