How to use multiple performance metrics with cross-validation
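I want to use accuracy, precision, recall and F-measure as performance metrics. When I use accuracy alone the code works fine, but with several metrics I get an error. How should I do this?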
import matplotlib.pyplot as plt
from sklearn import model_selection
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.metrics import make_scorer,accuracy_score,precision_score,recall_score,f1_score
scoring = {'accuracy' : make_scorer(accuracy_score),'precision' : make_scorer(precision_score),'recall' : make_scorer(recall_score),'f1_score' : make_scorer(f1_score)}
# load dataset (X_ and y are assumed to be loaded here)
# prepare configuration for cross validation test harness
seed = 7
# prepare models
models = []
models.append(('LR',LogisticRegression()))
models.append(('LDA',LinearDiscriminantAnalysis()))
#models.append(('KNN',KNeighborsClassifier()))
models.append(('CART',DecisionTreeClassifier()))
models.append(('NB',GaussianNB()))
models.append(('SVM',SVC()))
# evaluate each model in turn
results = []
names = []
#scoring = 'accuracy'
for name,model in models:
    # random_state only takes effect with shuffle=True (recent scikit-learn raises ValueError otherwise)
    kfold = model_selection.KFold(n_splits=5,shuffle=True,random_state=seed)
    cv_results = model_selection.cross_validate(model,X_,y,cv=kfold,scoring=scoring)
    results.append(cv_results)
    names.append(name)
    # with several scorers, cross_validate stores each one under "test_<key>", not "<key>"
    for metric in scoring:
        scores = cv_results['test_' + metric]
        print("%s %s: %f (%f)" % (name,metric,scores.mean(),scores.std()))
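Instead of wrapping each metric function in make_scorer, cross_validate also accepts a list of built-in scorer names such as 'precision' and 'f1'; the result keys then follow those names (test_f1 rather than test_f1_score). A minimal sketch, with make_classification data standing in for the question's X_ and y (an assumption for illustration):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_validate

# Synthetic binary data standing in for the question's X_ and y (assumption)
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# Built-in scorer names; each one produces a "test_<name>" key in the result
scoring = ['accuracy', 'precision', 'recall', 'f1']

kfold = KFold(n_splits=5, shuffle=True, random_state=7)
cv_results = cross_validate(LogisticRegression(max_iter=1000), X, y,
                            cv=kfold, scoring=scoring)

for metric in scoring:
    scores = cv_results['test_' + metric]
    print("%s: %f (%f)" % (metric, scores.mean(), scores.std()))
```

Either form works; the string list is shorter, while the make_scorer dict lets you pass extra arguments such as average='macro' for multiclass precision.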
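The code below displays the models' accuracy results for the case where accuracy is the only scoring metric. I want to modify it so that it works for the case above, where I have several scoring functions. How can I do this?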
# Boxplot algorithm comparison
fig = plt.figure()
fig.suptitle('Algorithm Comparison')
ax = fig.add_subplot(111)
ax.boxplot(results)  # single-metric case: results holds one score array per model
ax.set_xticklabels(names)
plt.show()
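For the multi-metric case, one option is to draw one boxplot panel per metric, with each panel comparing all models. The sketch below assumes `results` is the list of `cross_validate` dicts and `names` the model labels, as built by the loop in the question; `plot_metric_boxplots` is a hypothetical helper, and synthetic data with two models stands in for the real setup so the example runs on its own:

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs anywhere; drop this line to use plt.show()
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_validate
from sklearn.naive_bayes import GaussianNB


def plot_metric_boxplots(results, names, metrics):
    """Hypothetical helper: one boxplot panel per metric, one box per model."""
    fig, axes = plt.subplots(1, len(metrics), figsize=(4 * len(metrics), 4),
                             squeeze=False)
    fig.suptitle('Algorithm Comparison')
    for ax, metric in zip(axes[0], metrics):
        # Per-fold scores for this metric, taken from each model's result dict
        ax.boxplot([res['test_' + metric] for res in results])
        ax.set_xticklabels(names)
        ax.set_title(metric)
    fig.tight_layout()
    return fig


# Synthetic data and two models standing in for the question's setup (assumption)
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
models = [('LR', LogisticRegression(max_iter=1000)), ('NB', GaussianNB())]
metrics = ['accuracy', 'precision', 'recall', 'f1']
kfold = KFold(n_splits=5, shuffle=True, random_state=7)

results, names = [], []
for name, model in models:
    results.append(cross_validate(model, X, y, cv=kfold, scoring=metrics))
    names.append(name)

fig = plot_metric_boxplots(results, names, metrics)
```

In the question's script, calling `plot_metric_boxplots(results, names, list(scoring))` after the loop and then `plt.show()` would give one panel per key of the scoring dict.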
results contains the following values. How do I get the score for each metric?
[{'fit_time': array([0.05684781,0.03089881,0.04285073,0.03789902,0.04088998]),'score_time': array([0.00798011,0.00497937,0.00498676,0.00598478,0.00398898]),'test_accuracy': array([0.95977011,0.94827586,0.96551724,0.95677233,0.94524496]),'test_precision': array([0.95209581,0.94886364,0.97633136,0.97701149,0.93785311]),'test_recall': array([0.96363636,0.95375723,0.93922652,0.95402299]),'test_f1': array([0.95783133,0.96491228,0.95774648,0.94586895])},{'fit_time': array([0.01396322,0.00897574,0.01296639,0.0089767,0.01097035]),'score_time': array([0.0069809,0.0079782,0.00698042,0.0069809,0.00598478]),'test_accuracy': array([0.97701149,0.95402299,0.96264368,0.95389049,0.97982709]),'test_precision': array([0.99371069,0.97058824,0.99382716,1.,0.99408284]),'test_recall': array([0.95757576,0.9375,0.93063584,0.91160221,0.96551724]),'test_f1': array([0.97530864,0.96119403,0.97959184])},{'fit_time': array([0.00698161,0.00698113,0.0039897,0.00498629]),'score_time': array([0.00598383,0.00398827,0.00498652]),'test_accuracy': array([1.,0.99711816,1. ]),'test_precision': array([1.,1.]),'test_recall': array([1.,0.99447514,'test_f1': array([1.,0.99722992,1. ])},{'fit_time': array([0.00398946,0.00399137,0.00498724,0.00299191,0.00299263]),'score_time': array([0.00398922,0.00498629,0.00697994,0.00498533,0.00698185]),'test_accuracy': array([0.87068966,0.89655172,0.90229885,0.88760807,0.88184438]),'test_precision': array([0.78571429,0.83018868,0.83574879,0.82272727,0.80930233]),'test_f1': array([0.88,0.90721649,0.91052632,0.90274314,0.89460154])},{'fit_time': array([0.03992987,0.04884362,0.04388309,0.03992462,0.03992629]),'score_time': array([0.01694345,0.01100636,0.01097107,0.0119369,0.01093674]),'test_accuracy': array([0.9683908,0.95689655,0.97413793,'test_precision': array([0.99358974,0.9939759,'test_recall': array([0.93939394,0.91477273,0.95977011]),'test_f1': array([0.96573209,0.95548961,0.97345133,0.97947214])}]
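Solution
As described in the documentation:
For single metric evaluation, where the scoring parameter is a string, a callable or None, the keys will be ['test_score', 'fit_time', 'score_time'].
For multiple metric evaluation, the return value is a dict with one 'test_<scorer_name>' key per scorer, plus 'fit_time' and 'score_time'.
In your case, you can get the accuracy, recall, etc. with: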
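The difference in key names is easy to check directly. A small sketch on synthetic data (make_classification and a decision tree are illustrative choices, not from the question), assuming a recent scikit-learn where return_train_score defaults to False:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=100, random_state=0)
clf = DecisionTreeClassifier(random_state=0)

# Single metric (a string): per-fold scores live under 'test_score'
single = cross_validate(clf, X, y, cv=5, scoring='accuracy')
print(sorted(single))  # ['fit_time', 'score_time', 'test_score']

# Multiple metrics (list or dict): one 'test_<name>' key per scorer
multi = cross_validate(clf, X, y, cv=5, scoring=['accuracy', 'recall'])
print(sorted(multi))  # ['fit_time', 'score_time', 'test_accuracy', 'test_recall']
```

This is why code that reads cv_results['accuracy'] fails with a KeyError once scoring is a dict: the key is 'test_accuracy'.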
cv_results["test_accuracy"]
cv_results["test_recall"]
...
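Since results is a list with one such dict per model, pairing it with names gives a per-model summary. A sketch using a hypothetical `summarize` helper and synthetic data in place of the question's; with the question's scoring dict the metric names would be 'accuracy', 'precision', 'recall' and 'f1_score':

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, cross_validate
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier


def summarize(results, names, metrics):
    """Hypothetical helper: {model_name: {metric: (mean, std)}} from cross_validate dicts."""
    summary = {}
    for name, res in zip(names, results):
        summary[name] = {m: (res['test_' + m].mean(), res['test_' + m].std())
                         for m in metrics}
    return summary


# Synthetic data and two models standing in for the question's setup (assumption)
X, y = make_classification(n_samples=200, random_state=0)
metrics = ['accuracy', 'precision', 'recall', 'f1']
kfold = KFold(n_splits=5, shuffle=True, random_state=7)
names = ['CART', 'NB']
results = [cross_validate(m, X, y, cv=kfold, scoring=metrics)
           for m in (DecisionTreeClassifier(random_state=0), GaussianNB())]

summary = summarize(results, names, metrics)
for name, scores in summary.items():
    print(name, ', '.join('%s=%.3f (%.3f)' % (m, mean, std)
                          for m, (mean, std) in scores.items()))
```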