微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在sklearn中获取预测值和误差度量

我有两个单独的 python函数,其中一个使用cross_val_predict返回数据集的预测值,另一个使用cross_validate返回多个错误度量值.下面显示的是用于获取度量值的方法(我已经实现了类似的方法获取预测).

def metric_val(folds):
.
.
.
scoring = {'r_score': 'r2','abs_error': 'neg_mean_absolute_error','squared_error': 'neg_mean_squared_error'}

scores = cross_validate(best_svr,X,y,scoring=scoring,cv=folds,return_train_score=True)

print("****\nR2 :","",scores['test_r_score'].mean(),"| MAE :",scores['test_abs_error'].mean(),)
return prediction

我不想同时使用这两个函数,因为它的计算成本很高.是否有单一的方法或另一种方法来获得预测和指标?

解决方法

有可能装备一个得分手,以便它返回预测,虽然这有点像黑客.这是怎么做的:

cross_validate()函数可以使用自定义评分函数.评分函数必须返回一个数字,但您可以在函数内部执行任何操作.由于您拥有clf和所有测试数据,只需保存clf.predict()的输出,然后返回一个虚拟值以保持得分者满意.有关详细信息,请参阅Implementing your own scoring object上的sklearn docs.

像这样:

from sklearn import svm,datasets
from sklearn.model_selection import train_test_split,cross_validate,cross_val_predict

# example data
iris = datasets.load_iris()
X,y = iris.data,iris.target 
clf = svm.SVC(probability=True,random_state=0)

定义自定义get_preds()函数,将其作为得分者隐藏:

def get_preds(clf,y): # y is required for a scorer but we won't use it
    with open("pred.csv","ab+") as f: # append each fold to file
        np.savetxt(f,clf.predict(X))
    return 0

scoring = {'preds': get_preds,'accuracy': 'accuracy','recall': 'recall_macro'} # add desired scorers here

k = 5
cross_validate(clf,return_train_score=True,cv = k)

重新加载get_preds(),重新整形以匹配折叠集,并在折叠中平均:

preds = np.loadtxt("pred.csv").reshape(k,len(X))
my_preds = np.mean(my_preds,axis=0).round()

与cross_val_predict()预测比较:

cv_preds = cross_val_predict(clf,cv=k)

np.equal(my_preds,cv_preds).sum() # 487 out of 500

我们在makehift get_preds()方法和cross_val_predict()之间看到了几乎完美的一致.小分歧可能是由于我的平均方法与cross_val_predict不同(我只是舍入到最接近的整数类,不是非常复杂),或者它可能与sklearn cross-validation docs中这个稍微神秘的音符有关:

Note that the result of this computation may be slightly different from those obtained using cross_val_score as the elements are grouped in different ways.

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐