微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Pytorch - Tensorboard - Precision-Recall 曲线只显示一个点

如何解决Pytorch - Tensorboard - Precision-Recall 曲线只显示一个点

我正在训练一个包含 6 个类别的分类器,并且我的 pr 曲线具有以下功能

def predict_class_probabilities(model,features):
    model.eval()  # Switch to evaluation mode
    with no_grad():  # Don't want to calculate gradients here,only when training
        predictions = model(features)
        # Get class prediction probabilities
        prediction_class_probabilities = F.softmax(predictions,dim=1)

    return prediction_class_probabilities

def add_pr_curves_tensorboard(summary_writer,model,features,labels,global_step=0):

    prediction_class_probabilities = predict_class_probabilities(model,features)

    # Iterate over each class and add pr curve to summary_writer
    for class_index in range(num_classes):
        # Need binary prediction for class class_index for the add_pr_curve method
        binary_class_label = labels == class_index
        # Prediction probability for class class_index
        predictions_class_probability = prediction_class_probabilities[:,class_index]
        tag = 'class {}'.format(class_index)

    summary_writer.add_pr_curve(tag,binary_class_label,predictions_class_probability,global_step=global_step)

在训练我的模型后,我获得了 93% 的准确率(0 类和 1 类分别为 92% 和 96%),但是我的 0 类和 1 类的 pr 曲线看起来像这样(其他曲线看起来相似): pr curve class 0pr curve class 1。有人可以告诉我我在这里做错了什么吗? 最良好的祝愿, 托比亚斯

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。