Neural network does not generalize on highly imbalanced data


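I'm fairly new to machine learning and am currently trying to build a simple feed-forward neural network on heavily imbalanced data. The data consists of 64 different input variables (all normalized), from which the NN should predict 1 binary target variable (1 or 0). There are 43,405 rows: 2,091 of class 1 and 41,314 of class 0. The goal is high prediction accuracy for class 1.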
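With these counts, a trivial model that never predicts class 1 already reaches about 95% overall accuracy while finding none of the class-1 rows. That is why overall accuracy says little about class 1 here. Here is the arithmetic, using only the counts above (the variable names are illustrative):

```python
# Class counts taken from the description above
n_class0 = 41314
n_class1 = 2091
n_total = n_class0 + n_class1  # 43405 rows

# Share of each class in the data
majority_baseline = n_class0 / n_total  # accuracy of "always predict class 0"
positive_rate = n_class1 / n_total

# Confusion counts for the "always predict class 0" model
tp, fn = 0, n_class1  # every class-1 row is missed
tn, fp = n_class0, 0  # every class-0 row is (trivially) correct
accuracy = (tp + tn) / n_total
recall_class1 = tp / (tp + fn)

print(f"total rows:        {n_total}")
print(f"majority baseline: {majority_baseline:.4f}")
print(f"class-1 share:     {positive_rate:.4f}")
print(f"class-1 recall:    {recall_class1:.4f}")
```

Any overall accuracy near or below that baseline is therefore compatible with a network that has learned nothing about class 1.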

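I'm not really sure what is going on, but it seems to me that the NN is not learning the class 1 data (which is the important part). Before I implemented sample weights (I was not able to implement class weights), the output consistently showed an overall accuracy above 93%. After implementing sample weights, the overall accuracy dropped to around 40%, but class 1 accuracy stayed extremely low and did not change significantly. Changing the learning rate does not change anything, and neither does changing the architecture.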
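For comparison with a hand-picked weight, one common heuristic (the "balanced" mode of scikit-learn's `compute_class_weight`) sets each class weight to `n_samples / (n_classes * samples_in_class)`. The minimal sketch below applies that formula to the counts above. It then turns one-hot labels into a per-row weight array without a Python loop. The label array is a small hypothetical example, not the real data.

```python
import numpy as np

# Class counts from the question: index 0 -> class 0, index 1 -> class 1
counts = np.array([41314, 2091])
n_samples = counts.sum()
n_classes = len(counts)

# "Balanced" heuristic: n_samples / (n_classes * samples_in_class)
class_weights = n_samples / (n_classes * counts)  # approx. [0.5253, 10.3790]
class_weight_dict = {i: float(w) for i, w in enumerate(class_weights)}

# Class 1 ends up weighted about counts[0] / counts[1] times more than class 0
ratio = class_weights[1] / class_weights[0]

# Per-row weights from one-hot labels (hypothetical toy labels; y_train in the question)
y_onehot = np.array([[1, 0], [0, 1], [1, 0], [1, 0]], dtype=float)
sample_weights = class_weights[y_onehot.argmax(axis=1)]

print(class_weight_dict)
print(round(float(ratio), 2))
print(sample_weights)
```

Up to a common scale factor this corresponds to `{0: 1, 1: ~19.8}`. Keras `model.fit` also accepts a `class_weight` dict like `class_weight_dict`. Passing that dict, or passing a derived `sample_weight` array, should weight the training loss in the same way.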

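I don't know what I'm doing wrong or how to fix this. Since this is quite important for my thesis, I'd be very grateful for any help! If anything is missing from my description or I wasn't clear enough, please ask!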

[Figure: History of NN training]

My code:

```python
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
from keras import backend as K
from sklearn.model_selection import train_test_split
import keras
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

df = pd.read_pickle('nfinaldf.pkl')
df = df.drop(columns=['index'])

x = df.drop(columns=['status'])
y = to_categorical(df.status)  # one-hot targets with shape (n_rows, 2)

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=50)

n_cols = x_train.shape[1]

model = Sequential()
model.add(Dense(60, activation='relu', input_shape=(n_cols,)))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(2, activation='softmax'))


def generate_sample_weights(training_data, class_weight_dictionary):
    # Map each one-hot row to the weight of the class whose column holds the 1
    sample_weights = [class_weight_dictionary[np.where(one_hot_row == 1)[0][0]]
                      for one_hot_row in training_data]
    return np.asarray(sample_weights)


class_weights_dict = {0: 1, 1: 50}

# `learning_rate` is the current argument name (`lr` is deprecated)
optimizer = keras.optimizers.Adam(learning_rate=0.0001)


# y_true / y_pred have shape (batch, 2). Both metrics look only at column 1
# (class 1); summing over both one-hot columns would just give overall accuracy.
# K.clip needs both bounds: K.clip(x, min_value, max_value).
def sensitivity(y_true, y_pred):
    y_true, y_pred = y_true[:, 1], y_pred[:, 1]
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())


def specificity(y_true, y_pred):
    y_true, y_pred = y_true[:, 1], y_pred[:, 1]
    true_negatives = K.sum(K.round(K.clip((1 - y_true) * (1 - y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1 - y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())


INTERESTING_CLASS_ID = 1


def single_class_accuracy(y_true, y_pred):
    class_id_true = K.argmax(y_true, axis=-1)
    class_id_preds = K.argmax(y_pred, axis=-1)
    # The mask keeps rows *predicted* as class 1, so this is class-1 precision;
    # build the mask from class_id_true instead to get class-1 recall.
    accuracy_mask = K.cast(K.equal(class_id_preds, INTERESTING_CLASS_ID), 'int32')
    class_acc_tensor = K.cast(K.equal(class_id_true, class_id_preds), 'int32') * accuracy_mask
    class_acc = K.sum(class_acc_tensor) / K.maximum(K.sum(accuracy_mask), 1)
    return class_acc


# A 2-unit softmax with one-hot targets pairs with categorical cross-entropy
model.compile(optimizer=optimizer, loss='categorical_crossentropy',
              metrics=[sensitivity, specificity, 'accuracy', single_class_accuracy])

early_stopping_monitor = EarlyStopping(patience=30)

# fit needs the targets as its second argument: fit(x, y, ...)
history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    epochs=80,
                    callbacks=[early_stopping_monitor],
                    sample_weight=generate_sample_weights(y_train, class_weights_dict))
score = model.evaluate(x_test, y_test, verbose=0)

fig, axs = plt.subplots(4)
fig.suptitle('Vertically stacked subplots')
axs[0].plot(history.history['sensitivity'], label='sensitivity (training data)')
axs[0].plot(history.history['specificity'], label='specificity (training data)')
axs[0].legend(loc="upper left")
axs[1].plot(history.history['val_sensitivity'], label='val_sensitivity (validation data)')
axs[1].plot(history.history['val_specificity'], label='val_specificity (validation data)')
axs[1].legend(loc="upper left")
axs[2].plot(history.history['loss'], label='loss (training data)')
axs[2].legend(loc="upper left")
axs[3].plot(history.history['single_class_accuracy'], label='single_class_accuracy (training data)')
axs[3].legend(loc="upper left")
plt.show()
```
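Per-class metrics have to be computed carefully on one-hot targets. When `y_true` has two one-hot columns, a recall formula applied to the whole tensor counts correct predictions of *both* classes, so it reports overall accuracy instead of class-1 recall. Likewise, a mask built from the *predicted* class (as in `single_class_accuracy`) yields class-1 precision rather than recall. The small NumPy sketch below shows the difference. The helper names and the two toy batches are hypothetical, not taken from the real data.

```python
import numpy as np

EPS = 1e-7


def onehot_sum_metric(y_true, y_pred):
    # Recall-style formula applied to the full (N, 2) one-hot tensors
    hits = np.sum(np.round(np.clip(y_true * y_pred, 0, 1)))
    total = np.sum(np.round(np.clip(y_true, 0, 1)))
    return hits / (total + EPS)


def class1_recall(y_true, y_pred):
    # Restrict to column 1 first: TP / (TP + FN) for class 1
    t = y_true[:, 1]
    p = np.round(np.clip(y_pred[:, 1], 0, 1))
    return np.sum(t * p) / (np.sum(t) + EPS)


def class1_precision(y_true, y_pred):
    # Mask on *predicted* class 1, like single_class_accuracy: TP / (TP + FP)
    true_ids = y_true.argmax(axis=1)
    pred_ids = y_pred.argmax(axis=1)
    mask = pred_ids == 1
    return np.sum((true_ids == pred_ids) & mask) / max(np.sum(mask), 1)


# Toy batch: 9 class-0 rows and 1 class-1 row
y_true = np.array([[1, 0]] * 9 + [[0, 1]], dtype=float)

# Model A predicts class 0 everywhere
y_pred = np.array([[0.9, 0.1]] * 10)
print(onehot_sum_metric(y_true, y_pred))  # about 0.9, i.e. overall accuracy
print(class1_recall(y_true, y_pred))
print(class1_precision(y_true, y_pred))

# Model B catches the positive but also flags 3 negatives as class 1
y_pred2 = np.array([[0.9, 0.1]] * 6 + [[0.2, 0.8]] * 3 + [[0.3, 0.7]])
print(class1_recall(y_true, y_pred2))     # about 1.0
print(class1_precision(y_true, y_pred2))  # 0.25
```

Keras also ships `keras.metrics.Recall(class_id=1)` and `keras.metrics.Precision(class_id=1)`, which compute class-1 recall and precision (at a 0.5 threshold) on one-hot targets. They accumulate counts over the whole epoch instead of averaging per-batch values. This matters when many batches contain only a few class-1 rows.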



Overview of data:

```
index     Attr1    Attr2    Attr3   Attr4    Attr5    Attr6     Attr7  \
0    0.0  0.200550  0.37951  0.39641  2.0472  32.3510  0.38825  0.249760
1    1.0  0.209120  0.49988  0.47225  1.9447  14.7860  0.00000  0.258340
2    2.0  0.248660  0.69592  0.26713  1.5548  -1.1523  0.00000  0.309060
3    3.0  0.081483  0.30734  0.45879  2.4928  51.9520  0.14988  0.092704
4    4.0  0.187320  0.61323  0.22960  1.4063  -7.3128  0.18732  0.187320

     Attr8   Attr9   Attr10    Attr11   Attr12    Attr13    Attr14   Attr15  \
0  1.33050  1.1389  0.50494  0.249760  0.65980  0.166600  0.249760   497.42
1  0.99601  1.6996  0.49788  0.261140  0.51680  0.158350  0.258340   677.96
2  0.43695  1.3090  0.30408  0.312580  0.64184  0.244350  0.309060   794.16
3  1.86610  1.0571  0.57353  0.092704  0.30163  0.094257  0.092704   917.01
4  0.63070  1.1559  0.38677  0.187320  0.33147  0.121820  0.187320  1133.20

    Attr16  Attr17    Attr18    Attr19  Attr20   Attr22    Attr23    Attr24  \
0  0.73378  2.6349  0.249760  0.149420  43.370  0.21402  0.119980  0.477060
1  0.53838  2.0005  0.258340  0.152000  87.981  0.24806  0.123040  0.292903
2  0.45961  1.4369  0.309060  0.236100  73.133  0.30260  0.189960  0.300091
3  0.39803  3.2537  0.092704  0.071428  79.788  0.11550  0.062782  0.171930
4  0.32211  1.6307  0.187320  0.115530  57.045  0.19832  0.115530  0.187320

    Attr25   Attr26    Attr27   Attr28  Attr29   Attr30    Attr31  Attr32  \
0  0.50494  0.60411   1.45820   1.7615  5.9443  0.11788  0.149420   94.14
1  0.39542  0.43992  88.44400  16.9460  3.6884  0.26969  0.152000  122.17
2  0.28932  0.37282  86.01100   1.0627  4.3749  0.41929  0.238150  176.93
3  0.57353  0.36152   0.94076   1.9618  4.6511  0.14343  0.071428   91.37
4  0.38677  0.32211   1.41380   1.1184  4.1424  0.27884  0.115530  147.04

   Attr33   Attr34   Attr35  Attr36   Attr38    Attr39    Attr40    Attr41  \
0  3.8772  0.56393  0.21402  1.7410  0.50591  0.128040  0.662950  0.051402
1  2.9876  2.98760  0.20616  1.6996  0.49788  0.121300  0.086422  0.064371
2  2.0630  1.42740  0.31565  1.3090  0.51537  0.241140  0.322020  0.074020
3  3.9948  0.37581  0.11550  1.3562  0.57353  0.088995  0.401390  0.069622
4  2.4823  0.32340  0.19832  1.6278  0.43489  0.122310  0.293040  0.096680

     Attr42  Attr43   Attr44   Attr45   Attr46   Attr47    Attr48    Attr49  \
0  0.128040  114.42   71.050  1.00970  1.52250   49.394  0.185300  0.110850
1  0.145950  199.49  111.510  0.51045  1.12520  100.130  0.237270  0.139610
2  0.231170  165.51   92.381  0.94807  1.01010   96.372  0.291810  0.222930
3  0.088995  180.77  100.980  0.28720  1.56960   84.344  0.085874  0.066165
4  0.122310  141.62   84.574  0.73919  0.95787   65.936  0.188110  0.116010

   Attr50   Attr51   Attr52   Attr53   Attr54    Attr55    Attr56   Attr57  \
0  2.0420  0.37854  0.25792   2.2437   2.2480  348690.0  0.121960  0.39718
1  1.9447  0.49988  0.33472  17.8660  17.8660    2304.6  0.121300  0.42002
2  1.0758  0.48152  0.48474   1.2098   2.0504    6332.7  0.241140  0.81774
3  2.4928  0.30734  0.25033   2.4524   2.4524   20545.0  0.054015  0.14207
4  1.2959  0.56511  0.40285   1.8839   2.1184    3186.6  0.134850  0.48431

    Attr58    Attr59  Attr60  Attr61   Attr62  Attr63   Attr64  status  year
0  0.87804  0.001924  8.4160  5.1372   82.658  4.4158   7.4277     0.0   1.0
1  0.85300  0.000000  4.1486  3.2732  107.350  3.4000  60.9870     0.0   1.0
2  0.76599  0.694840  4.9909  3.9510  134.270  2.7185   5.2078     0.0   1.0
3  0.94598  0.000000  4.5746  3.6147   86.435  4.2228   5.5497     0.0   1.0
4  0.86515  0.124440  6.3985  4.3158  127.210  2.8692   7.8980     0.0   1.0
```
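The rows above show columns on very different scales. In these five rows, `Attr55` ranges from 2304.6 to 348690.0 and `Attr15` from 497.42 to 1133.20. Unless scaling happens somewhere else, the features may therefore not be on a common scale. With only about 4.8% positives, it also helps to keep the class ratio equal in the training and test splits. The minimal NumPy sketch below shows both steps on hypothetical toy data, with illustrative helper names. In the real script, scikit-learn's `train_test_split(..., stratify=...)` and a `StandardScaler` fitted on the training split only should do the same job.

```python
import numpy as np


def stratified_split(y, test_size=0.33, seed=50):
    """Split row indices so each class keeps (about) the same share in train and test."""
    rng = np.random.default_rng(seed)
    train_parts, test_parts = [], []
    for cls in np.unique(y):
        # Shuffle the indices of this class, then cut off the test portion
        idx = rng.permutation(np.flatnonzero(y == cls))
        n_test = int(round(test_size * len(idx)))
        test_parts.append(idx[:n_test])
        train_parts.append(idx[n_test:])
    return np.concatenate(train_parts), np.concatenate(test_parts)


def standardize(train, test):
    """Scale both splits using mean/std computed on the training split only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)
    std[std == 0] = 1.0  # leave constant columns unscaled instead of dividing by 0
    return (train - mean) / std, (test - mean) / std


# Hypothetical toy data: 100 rows, 5 positives, one feature on an Attr55-like scale
rng = np.random.default_rng(0)
y = np.array([1] * 5 + [0] * 95)
X = np.column_stack([rng.normal(size=100), rng.uniform(2000, 350000, size=100)])

train_idx, test_idx = stratified_split(y)
X_train, X_test = standardize(X[train_idx], X[test_idx])
print("class-1 share (train, test):", y[train_idx].mean(), y[test_idx].mean())
print("train column means:", X_train.mean(axis=0).round(6))
```

Fitting the scaler on the training rows only keeps test-set statistics out of training.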

