微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用批处理计算 BCE 反向传播中的 dx

如何解决使用批处理计算 BCE 反向传播中的 dx

我正在尝试使用批处理从头开始反向传播,但在计算 dx 时遇到问题。首先,我想先定义变量以避免混淆:

a - The activation value calculated by passing z through an activation function

z - The value before the activation function of the layer

x - The inputs into the layer

w - The weights that connect the inputs to the output nodes

da - The derivative of a

dz - The derivative of z

dx - The derivative of x

我知道这是 x 的导数:

dx = w.T*dz
Note: * means dot and .T means transpose

现在让我介绍一下这个问题。假设我有一个具有 2 个输入、3 个输出节点和 5 个批次大小的神经网络。我将如何计算 dx?在这种情况下,权重在转置之前的形状为 (z,x) 或 (3,2),而 dz 的形状为 (z,batches) 或 (3,5)。如果我使用上面的公式,我会得到 (x,batches) 或 (2,5) 的形状。在使用上面的公式得到 dx(导致形状为 (2,1))之后,我会取最后一个维度的总和吗?下面是使用虚构值的点积表示:

     w.T          *          dz            =        dx
                      [[1,2,3,4,5],[[1,0.5,1],*    [1,=   [[2.5,5,7.5,10,12.5],[-1,-1,-0.5]       [1,5]]         [-2.5,-5,-7.5,-10,-12.5]] 

解决方法

你做的一切都是正确的。在反向传播中,X 始终需要与 dX 具有相同的维度。如果 X 是形状 (2,5) 的中间结果,则梯度也具有形状 (2,5)。通过这种方式,您可以更新矩阵 X。现在在您的情况下,矩阵 X 是输入矩阵,您将永远不会更新它。您只需要更新 W。

如果 X 是隐藏层的结果,则您对反向传播梯度的计算是正确的。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?