微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 PyTorch 的 DCGAN 鉴别器准确度度量

如何解决使用 PyTorch 的 DCGAN 鉴别器准确度度量

我正在使用 PyTorch 实现 DCGAN

效果很好,因为我可以获得质量合理的生成图像,但是现在我想通过使用指标来评估 GAN 模型的健康状况,主要是本指南中介绍的指标 https://machinelearningmastery.com/practical-guide-to-gan-failure-modes/

他们的实现使用 Keras,其中 SDK 允许您在编译模型时定义所需的指标,请参阅 https://keras.io/api/models/model/在这种情况下,鉴别器的准确性,即它成功将图像识别为真实图像或生成图像的百分比。

使用 PyTorch SDK,我似乎找不到类似的功能来帮助我从模型中轻松获取此指标。

Pytorch 是否提供能够从模型中定义和提取通用指标的功能

解决方法

纯 PyTorch提供开箱即用的指标,但您自己定义这些指标非常容易。

也没有“从模型中提取指标”这样的东西。指标是指标,它们衡量(在这种情况下是鉴别器的准确性),它们不是模型固有的。

二进制精度

就您而言,您正在寻找二进制精度指标。下面的代码适用于 logits(由 discriminator 输出的非标准化概率,可能是最后一个 nn.Linear 层没有激活)或 probabilities(最后一个 nn.Linear 后跟 {{1 }} 激活):

sigmoid

用法:

import typing
import torch


class BinaryAccuracy:
    def __init__(
        self,logits: bool = True,reduction: typing.Callable[
            [
                torch.Tensor,],torch.Tensor,] = torch.mean,):
        self.logits = logits
        if logits:
            self.threshold = 0
        else:
            self.threshold = 0.5

        self.reduction = reduction

    def __call__(self,y_pred,y_true):
        return self.reduction(((y_pred > self.threshold) == y_true.bool()).float())

PyTorch Lightning 或其他第三方

您还可以在 PyTorch 之上使用 PyTorch Lightning 或其他定义指标的框架,例如 accuracy

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?