微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

PyTorch二进制分类-相同的网络结构,“更简单”的数据,但性能较差?

如何解决PyTorch二进制分类-相同的网络结构,“更简单”的数据,但性能较差?

TL; DR

您的输入数据未标准化。

  1. 采用 x_data = (x_data - x_data.mean()) / x_data.std()
  2. 提高学习率 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

你会得到在此处输入图片说明

仅1000次迭代即可收敛。

更多细节

这两个示例之间的主要区别在于x,第一个示例中的数据以(0,0)为中心,并且方差很小。 另一方面,第二示例中的数据以92为中心,并且具有相对较大的方差。

当您随机初始化权重时,不会考虑数据中的初始偏差,该权重是基于假设输入大致呈​​正态分布在 附近的假设完成的。 优化过程几乎不可能补偿该总偏差-因此模型陷入了次优的解决方案。

将输入标准化后,通过减去平均值并除以std,优化过程将再次变得稳定,并迅速收敛为一个好的解决方案。

有关输入归一化和权重初始化的更多详细信息,您可以阅读 He等人 (ICCV 2015)中的第2.2节。

如果我无法规范化数据怎么办?

如果由于某种原因您无法提前计算均值和标准数据,则仍可以nn.BatchNorm1d在训练过程中使用它来估计和标准化数据。例如

class Model(nn.Module):
    def __init__(self, input_size, H1, output_size):
        super().__init__()
        self.bn = nn.Batchnorm1d(input_size)  # adding batchnorm
        self.linear = nn.Linear(input_size, H1)
        self.linear2 = nn.Linear(H1, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.linear(self.bn(x)))  # batchnorm the input x
        x = torch.sigmoid(self.linear2(x))
        return x

这种修改 不会 对输入数据 进行 任何更改,仅在1000个纪元后就会产生类似的收敛性:在此处输入图片说明

轻微评论

为了数值稳定,最好使用nn.BCEWithLogitsLoss代替nn.BCELoss。为此,您需要torch.sigmoidforward()输出删除,这sigmoid将在损失内进行计算。 。

解决方法

为了掌握PyTorch(以及一般的深度学习),我首先研究了一些基本的分类示例。一个这样的示例是对使用sklearn创建的非线性数据集进行分类(完整代码可在此处作为笔记本查看

n_pts = 500
X,y = datasets.make_circles(n_samples=n_pts,random_state=123,noise=0.1,factor=0.2)
x_data = torch.FloatTensor(X)
y_data = torch.FloatTensor(y.reshape(500,1))

然后使用相当基本的神经网络将其准确分类

class Model(nn.Module):
    def __init__(self,input_size,H1,output_size):
        super().__init__()
        self.linear = nn.Linear(input_size,H1)
        self.linear2 = nn.Linear(H1,output_size)

    def forward(self,x):
        x = torch.sigmoid(self.linear(x))
        x = torch.sigmoid(self.linear2(x))
        return x

    def predict(self,x):
        pred = self.forward(x)
        if pred >= 0.5:
            return 1
        else:
            return 0

当我对健康数据感兴趣时,我决定尝试使用相同的网络结构对一些基本的现实世界数据集进行分类。我从这里获取了一名患者的心率数据,并对其进行了更改,以便所有>
91的值都被标记为异常(例如a 1,所有<= 91的值都标记为a
0)。这是完全任意的,但是我只是想看看分类是如何工作的。此示例的完整笔记本在这里

对我而言不直观的是,为什么第一个示例 在1,000个历元后损失0.0016 ,而第二个示例 在10,000个历元后却损失0.4296

也许我天真地认为心率示例更容易分类。任何能帮助我理解为什么这不是我所看到的见解都会很棒!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。