如何解决取数据帧中每一行的 predict_proba 的最大值
我在使用 predict_proba
函数时遇到问题。我有一个多类分类问题并使用随机森林分类器。我想打印数据和相应的预测类别+该类别的预测概率。首先,我可以设法做到 1 个预测。
一个示例的代码
y_pred=pickle_model.predict(df_test)
y_pred_prob = pickle_model.predict_proba(df_test)
ix = y_pred_prob.argmax(1).item()
list = []
list.append(y_pred[iy])
list.append(f'{y_pred_prob[0,ix]:.2%}')
然而,当我给出一个包含超过 1 行项目的测试数据集时,我一直在挣扎。
我尝试了以下函数并将其逐行应用于 df 。但是,我只能得到一系列概率。无法弄清楚如何获得每行的最大值。当只有一行时,我使用 argmax
如下所示。
这是多样本测试数据的代码:
def get_predict_proba(row,model):
return model.predict_proba(row.values.reshape(1,-1))
df['predicted_category'] = pickle_model.predict(df)
df['confidence'] = df.apply(lambda row: get_predict_proba(row,pickle_model),axis=1)
这给出了这样的输出:
ID | 功能1 | 功能 2 | predicted_category | 信心 |
---|---|---|---|---|
### | ######## | ######## | category_name 1 | predict_proba 数组 |
### | ######## | ######## | category_name 2 | predict_proba 数组 |
预期的输出是这样的:
ID | 功能1 | 功能 2 | predicted_category | 信心 |
---|---|---|---|---|
### | ######## | ######## | category_name 1 | category_name 1 的概率值 |
### | ######## | ######## | category_name 2 | category_name 2 的概率值 |
我的第二个问题是 predict_proba
函数在具有多个类的随机森林分类器上的可靠性。它真的给出了正确的比例吗?我有一个类之间不平衡的数据集。如果没有,是否有更好的选择或解决此问题的方法?
感谢您的帮助。
解决方法
我已经解决了像这样编辑 get_predict_proba 函数的问题:
def get_predict_proba(row,model):
y_pred_prob=model.predict_proba(row.values.reshape(1,-1))
ix = y_pred_prob.argmax(1).item()
return (f'{y_pred_prob[0,ix]:.2%}')
我仍然需要更深入的关于 predict_proba 的信息,以及它如何在具有不平衡类的多类分类器上工作。此外,如果有更有效的方法来解决这个问题,我很乐意看到。谢谢
,预测值的第一个“predict_proba”
正如您所说,predict_proba
以您拥有的每个类的概率返回和数组。当然,这个数组的最大值或最大值将对应于预测的类。因此,简单的解决方案是返回 predict_proba
数组的最大值:
def get_predict_proba(row,model):
return max(model.predict_proba(row.values.reshape(1,-1)))
不平衡类中的第二个问题“predict_proba”
作为定义,predict_proba
是 “输入样本的预测类别概率计算为森林中树木的平均预测类别概率。单棵树的类别概率是叶子中相同类别的样本”。
这意味着您正在评估随机森林中每棵树的预测,并由此获得概率。如果您的随机森林在类别不平衡的情况下表现良好,那么这是一个不错的方法。
总而言之,我将专注于获得一个好的随机森林模型,如果它能够很好地处理不平衡的类,那么 predict_proba
将具有代表性。如果随机森林不够好,你将不得不使用一些技术来解决不平衡类的问题(例如过采样或欠采样)。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。