用交叉验证训练 8 个不同的分类器,对同一个文件给出相同的准确度?

如何解决用交叉验证训练 8 个不同的分类器,对同一个文件给出相同的准确度?

我有下面的脚本,它应该使用交叉验证来训练不同的模型,然后计算平均准确度,以便我可以使用最佳模型进行分类任务。但我对每个分类器都得到了相同的结果。

结果如下:

---Filename in processed................ corpusAmazon_train
etiquette  : [0 1]
Embeddings bert model used.................... :  sm
Model name: Model_LSVC_ovr
------------cross val predict used---------------- 

accuracy with cross_val_predict : 0.6582974014576258
corpusAmazon_train file terminated--- 

---------------cross val score used ----------------------- 

[0.66348722 0.66234262 0.63334605 0.66959176 0.66081648 0.6463182
 0.66730256 0.65572519 0.65648855 0.66755725]
0.66 accuracy with a standard deviation of 0.01 

Model name: Model_G_NB
------------cross val predict used---------------- 

accuracy with cross_val_predict : 0.6582974014576258
corpusAmazon_train file terminated--- 

---------------cross val score used ----------------------- 

[0.66348722 0.66234262 0.63334605 0.66959176 0.66081648 0.6463182
 0.66730256 0.65572519 0.65648855 0.66755725]
0.66 accuracy with a standard deviation of 0.01 

Model name: Model_LR
------------cross val predict used---------------- 

accuracy with cross_val_predict : 0.6582974014576258
corpusAmazon_train file terminated--- 

---------------cross val score used ----------------------- 

[0.66348722 0.66234262 0.63334605 0.66959176 0.66081648 0.6463182
 0.66730256 0.65572519 0.65648855 0.66755725]
0.66 accuracy with a standard deviation of 0.01 

使用 cross_validation 的代码行:

models_list = {'Model_LSVC_ovr': model1,'Model_G_NB': model2,'Model_LR': model3,'Model_RF': model4,'Model_KN': model5,'Model_MLP': model6,'Model_LDA': model7,'Model_XGB': model8}

# cross_validation
def cross_validation(features,ylabels,models_list,n,lge_model):

    cv_splitter = KFold(n_splits=10,shuffle=True,random_state=42)
    features,s = get_flaubert_layer(features,lge_model)
    for model_name,model in models_list.items():
        print("Model name: {}".format(model_name))
        print("------------cross val predict used----------------","\n")
        y_pred = cross_val_predict(model,features,cv=cv_splitter,verbose=1)
        accuracy_score_predict = accuracy_score(ylabels,y_pred)
        print("accuracy with cross_val_predict :",accuracy_score_predict)

        print("---------------cross val score used -----------------------","\n")
        scores = cross_val_score(model,scoring='accuracy',cv=cv_splitter)

        print("%0.2f accuracy with a standard deviation of %0.2f" % (accuracy_score_mean,accuracy_score_std),"\n")

即使在使用 cross_val_score 时,模型的准确度也相同。任何想法,也许我在 cross_validation 函数中使用了 random_state ?

模型定义代码

def classifiers_b():

    model1 = LinearSVC()
    model2 = GaussianNB()  # MultinomialNB() X cannot be a non-negative
    model3 = LogisticRegression()
    model4 = RandomForestClassifier()
    model5 = KNeighborsClassifier()
    model6 = MLPClassifier(hidden_layer_sizes=(50,100,50),max_iter=500,activation='relu',solver='adam',random_state=1)
    model8 = XGBClassifier(eval_metric="logloss")
    model7 = LineardiscriminantAnalysis()

    #models_list = {'Model_LSVC_ovr': model1,'Model_XGB': model8}

解决方法

我建议为每个模型使用一个管道。看起来您在每次迭代中都在同一模型上执行 CV。您可以查看文档 here 以获取有关如何使用它们的更多信息。然后对每个模型管道执行 CV。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?