微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

绘制阈值precision_recall 曲线 matplotlib/sklearn.metrics

如何解决绘制阈值precision_recall 曲线 matplotlib/sklearn.metrics

我正在尝试绘制精度/召回曲线的阈值。我只是使用 MNSIT 数据,并以《使用 scikit-learn、keras 和 TensorFlow 进行机器学习动手》一书中的示例为例。尝试训练模型检测 5 的图像。不知道你需要看多少代码。我已经为训练集制作了混淆矩阵,并计算了精度和召回值以及阈值。我已经绘制了 pre/rec 曲线,书中的示例说要添加标签、壁架、网格并突出显示阈值,但代码在我在下面放置星号的书中被切断了。除了如何让阈值显示在情节上之外,我能够弄清楚所有事情。我已经附上了一张图,说明了书中的图表与我所拥有的相比。 这就是这本书的内容

enter image description here

对比我的图表:

enter image description here

我无法显示带有两个阈值点的红色点线。有谁知道我将如何做到这一点?这是我的代码如下:

from sklearn.metrics import precision_recall_curve

precisions,recalls,thresholds = precision_recall_curve(y_train_5,y_scores)

def plot_precision_recall_vs_thresholds(precisions,thresholds):
    plt.plot(thresholds,precisions[:-1],"b--",label="Precision")
    plt.plot(thresholds,recalls[:-1],"g--",label="Recall")
    plt.xlabel("Threshold")
    plt.legend(bBox_to_anchor=(1.05,1),loc='upper left',borderaxespad=0.)
    plt.grid(b=True,which="both",axis="both",color='gray',linestyle='-',linewidth=1)

plot_precision_recall_vs_thresholds(precisions,thresholds)
plt.show()

我知道这里有很多关于 sklearn 的问题,但似乎没有一个问题包括让那条红线出现。非常感谢您的帮助!

解决方法

您可以使用以下代码绘制水平线和垂直线:

plt.axhline(y_value,c='r',ls=':')
plt.axvline(x_value,ls=':')
,

这应该以正确的方式工作:

def plot_precision_recall_vs_threshold(precisions,recalls,thresholds):
    recall_80_precision = recalls[np.argmax(precisions >= 0.80)]
    threshold_80_precision = thresholds[np.argmax(precisions >= 0.80)]
    
    plt.plot(thresholds,precisions[:-1],"b--",label="Precision",linewidth=2)
    plt.plot(thresholds,recalls[:-1],"g-",label="Recall",linewidth=2)
    plt.xlabel("Threshold")
    plt.plot([threshold_80_precision,threshold_80_precision],[0.,0.8],"r:")
    plt.axis([-4,4,1])
    plt.plot([-4,[0.8,"r:")
    plt.plot([-4,[recall_80_precision,recall_80_precision],"r:")
    plt.plot([threshold_80_precision],[0.8],"ro") 
    plt.plot([threshold_80_precision],[recall_80_precision],"ro")
    plt.grid(True)
    plt.legend()
    plt.show()

我在尝试复制本书中的代码时遇到了这段代码。结果 @ageron 将所有资源都放在了他的 github 页面上。你可以看看here

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。