Computing gradients of a NN in pure Python

How to compute the gradients of a NN in pure Python

import numpy

# Data and parameters (shapes: X (1, 4), W1 (4, 6), W2 and W3 (6, 6), W4 (6, 1))

X  = numpy.array([[-1.086,0.997,0.283,-1.506]])
T  = numpy.array([[-0.579]])
W1 = numpy.array([[-0.339,-0.047,0.746,-0.319,-0.222,-0.217],[ 1.103,1.093,0.502,0.193,0.369,0.745],[-0.468,0.588,-0.627,0.454,-0.714],[-0.070,-0.431,-0.128,-1.399,-0.886,-0.350]])
W2 = numpy.array([[ 0.379,-0.071,0.001,0.281,-0.359,0.116],[-0.329,-0.705,-0.160,0.234,0.138,-0.005],[ 0.977,0.169,0.400,0.914,-0.528,-0.424],[ 0.712,-0.326,0.012,0.437,0.364,0.716],[ 0.611,-0.315,0.325,0.128,-0.541],[ 0.579,0.330,0.019,-0.095,-0.489,0.081]])
W3 = numpy.array([[ 0.191,-0.339,0.474,-0.448,-0.867,0.424],[-0.165,-0.051,-0.342,-0.656,0.512,-0.281],[ 0.678,-0.443,-0.299,-0.495],[ 0.852,0.067,0.470,-0.517,0.074,0.481],[-0.137,0.421,-0.557,0.155,-0.155],[ 0.262,-0.807,0.291,1.061,-0.010,0.014]])
W4 = numpy.array([[ 0.073],[-0.760],[ 0.174],[-0.655],[-0.175],[ 0.507]])
B1 = numpy.array([-0.760,0.174,-0.655,-0.175,0.507,-0.300])
B2 = numpy.array([ 0.205,0.413,0.114,-0.560,-0.136,0.800])
B3 = numpy.array([-0.827,-0.113,-0.225,0.049,0.305,0.657])
B4 = numpy.array([-0.270])
W = [W1, W2, W3, W4]  # the forward pass below indexes these lists
B = [B1, B2, B3, B4]

# Forward pass

Z1 = X.dot(W[0])+B[0]
A1 = numpy.maximum(0,Z1)
Z2 = A1.dot(W[1])+B[1]
A2 = numpy.maximum(0,Z2)
Z3 = A2.dot(W[2])+B[2]
A3 = numpy.maximum(0,Z3)
Y  = A3.dot(W[3])+B[3]

# Error

err = ((Y-T)**2).mean()

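Given this example, I want to implement backpropagation and get the gradients with respect to the weight and bias parameters. Apparently, the gradients for the last layer are as follows: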

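DY = 2*(Y-T)                    # dE/dY per sample (the 1/N of the mean is applied below)
DB4 = DY.mean(axis=0)           # average over the batch
DW4 = A3.T.dot(DY) / len(X)     # sum over the batch, divided by the batch size
DZ3 = DY.dot(W4.T)*(Z3 > 0)     # back through W4, then through the ReLU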
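The same pattern (multiply by the transposed weights, then by the ReLU mask) repeats for every earlier layer. Below is a minimal sketch of the full backward pass, assuming the loss err = ((Y-T)**2).mean() with a single output and the W/B list layout used in the forward pass. Random parameters and a random batch stand in for the question's values, and `forward`, `backward` and `loss` are hypothetical helper names. A central finite difference checks one first-layer weight.

```python
import numpy

def forward(X, W, B):
    """Return the output Y plus the pre-activations Zs and activations As."""
    A, Zs, As = X, [], [X]            # As[k] is the input to layer k
    for k in range(len(W) - 1):
        Z = A.dot(W[k]) + B[k]        # affine transform
        A = numpy.maximum(0, Z)       # ReLU
        Zs.append(Z)
        As.append(A)
    Y = A.dot(W[-1]) + B[-1]          # linear output layer, as in the question
    return Y, Zs, As

def backward(Y, T, W, Zs, As):
    """Gradients of err = ((Y - T)**2).mean() w.r.t. each W[k] and B[k] (single output)."""
    N = len(Y)
    DZ = 2 * (Y - T)                  # DY; the 1/N of the mean is applied per layer below
    DW, DB = [None] * len(W), [None] * len(W)
    for k in reversed(range(len(W))):
        DW[k] = As[k].T.dot(DZ) / N   # sum over the batch, divided by the batch size
        DB[k] = DZ.mean(axis=0)       # average over the batch
        if k > 0:
            # Zs[k - 1] is the pre-activation that produced As[k]
            DZ = DZ.dot(W[k].T) * (Zs[k - 1] > 0)
    return DW, DB

def loss(X, T, W, B):
    Y, _, _ = forward(X, W, B)
    return ((Y - T) ** 2).mean()

# Hypothetical random parameters with the question's layer sizes 4 -> 6 -> 6 -> 6 -> 1
rng = numpy.random.default_rng(1)
sizes = [4, 6, 6, 6, 1]
W = [rng.normal(scale=0.5, size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
B = [rng.normal(scale=0.5, size=n) for n in sizes[1:]]
X = rng.normal(size=(5, 4))           # hypothetical batch of 5 samples
T = rng.normal(size=(5, 1))

Y, Zs, As = forward(X, W, B)
DW, DB = backward(Y, T, W, Zs, As)

# Central finite-difference check of one first-layer weight
eps = 1e-6
Wp = [w.copy() for w in W]
Wm = [w.copy() for w in W]
Wp[0][1, 2] += eps
Wm[0][1, 2] -= eps
numeric = (loss(X, T, Wp, B) - loss(X, T, Wm, B)) / (2 * eps)
print(DW[0][1, 2], numeric)
```

For k = 3 the loop reproduces the question's DW4, DB4 and DZ3 exactly; each further iteration is one more application of the chain rule.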

I do understand how to compute derivatives with the chain rule, but I don't quite see how to arrive at this solution.

Solution

For example, DY is the derivative of err with respect to Y, so

d/dY (Y - T)**2 == 2 * (Y - T)

This is a plain old derivative; no chain rule yet.
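A quick numerical sanity check of this derivative, assuming an arbitrary output value y (hypothetical) and the question's target t: a central finite difference of (y - t)**2 should match 2 * (y - t).

```python
def sq_err(y, t):
    """Squared error for a single output."""
    return (y - t) ** 2

y, t = 0.35, -0.579        # hypothetical output, and the question's target
eps = 1e-6
numeric = (sq_err(y + eps, t) - sq_err(y - eps, t)) / (2 * eps)
analytic = 2 * (y - t)     # DY for this single sample
print(numeric, analytic)
```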

Looking at DB4, using the chain rule:

d/dB[3] err == d/dB[3] (A3 @ W[3] + B[3] - T)**2
== 2 * (A3 @ W[3] + B[3] - T) * d/dB[3] (A3 @ W[3] + B[3] - T)
== 2 * (A3 @ W[3] + B[3] - T) * 1
== 2 * (Y - T)
== DY
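With a batch of N samples, err is the mean of (Y - T)**2 over the batch, so the per-sample result DY gets averaged; that is where DY.mean(axis=0) in the question comes from. A minimal check, assuming a hypothetical batch of 4 samples with random A3, W4 and T; `err_fn` is an illustrative helper name.

```python
import numpy

rng = numpy.random.default_rng(2)
A3 = rng.normal(size=(4, 6))     # hypothetical last hidden activations, batch of 4
W4 = rng.normal(size=(6, 1))
T = rng.normal(size=(4, 1))

def err_fn(b4):
    Y = A3.dot(W4) + b4          # b4 is broadcast over the batch
    return ((Y - T) ** 2).mean()

b4 = numpy.array([-0.270])
DY = 2 * (A3.dot(W4) + b4 - T)   # per-sample derivative
DB4 = DY.mean(axis=0)            # averaged over the batch, shape (1,)

eps = 1e-6
numeric = (err_fn(b4 + eps) - err_fn(b4 - eps)) / (2 * eps)
print(DB4, numeric)
```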

DW4 is:

d/dW[3] err == d/dW[3] (A3 @ W[3] + B[3] - T)**2
== 2 * (A3 @ W[3] + B[3] - T) @ (d/dW[3] (A3 @ W[3] + B[3] - T))
== 2 * (Y - T) @ A3.T
[must match matrix shape]
== A3.T @ DY

The trick behind A3.T @ DY is that d/dW[3] (A3 @ W[3]) = A3.T; see https://math.stackexchange.com/questions/1866757/not-understanding-derivative-of-a-matrix-matrix-product
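One way to see the shape rule: for a scalar loss whose gradient with respect to Y = A3 @ W4 is G, the gradient with respect to W4 is A3.T @ G, which has exactly the shape of W4. A minimal numerical check, assuming random A3, W4 and G, and a linear test loss L = sum(G * (A3 @ W4)) chosen so that dL/dY is exactly G:

```python
import numpy

rng = numpy.random.default_rng(3)
A3 = rng.normal(size=(5, 6))     # hypothetical batch of 5 activation vectors
W4 = rng.normal(size=(6, 1))
G = rng.normal(size=(5, 1))      # stands in for dL/dY, e.g. DY

def L(W):
    return (G * A3.dot(W)).sum()  # linear in Y, so dL/dY == G

analytic = A3.T.dot(G)           # shape (6, 1), same as W4

# Central finite differences, one entry of W4 at a time
numeric = numpy.zeros_like(W4)
eps = 1e-6
for idx in numpy.ndindex(W4.shape):
    Wp, Wm = W4.copy(), W4.copy()
    Wp[idx] += eps
    Wm[idx] -= eps
    numeric[idx] = (L(Wp) - L(Wm)) / (2 * eps)
print(numpy.abs(analytic - numeric).max())
```

Writing G @ A3.T instead would not even type-check for a batch larger than 1, which is why the factors are reordered to A3.T @ DY "to match the matrix shape".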

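To get DZ3 == d/dZ3 err, you have to differentiate through the activation function used to compute A3, which in your case is ReLU. (TBH, I think Y = A3.dot(W[3])+B[3] should be Y = numpy.maximum(0,A3.dot(W[3])+B[3]), since the final output should be the result of an activation function, but maybe your network architecture just doesn't do that. Note, though, that a ReLU output is never negative, so it could not match a target such as T = -0.579; for regression, a linear output layer like yours is common.)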
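Differentiating through ReLU amounts to multiplying elementwise by a 0/1 mask, which is what the (Z3 > 0) factor in the question's DZ3 does. A sketch, assuming a hypothetical helper name `relu_backward` and random inputs; the derivative at exactly 0 is undefined, and this sketch uses 0 there, matching `Z3 > 0`.

```python
import numpy

def relu(z):
    return numpy.maximum(0, z)

def relu_backward(dA, Z):
    """Given dE/dA for A = relu(Z), return dE/dZ (using 0 as the derivative at Z == 0)."""
    return dA * (Z > 0)

rng = numpy.random.default_rng(4)
Z3 = rng.normal(size=(3, 6))     # hypothetical pre-activations
dA3 = rng.normal(size=(3, 6))    # hypothetical upstream gradient, e.g. DY.dot(W4.T)

analytic = relu_backward(dA3, Z3)

# ReLU acts elementwise, so one central difference of the whole array gives relu'(Z3)
eps = 1e-6
numeric = dA3 * (relu(Z3 + eps) - relu(Z3 - eps)) / (2 * eps)
print(numpy.abs(analytic - numeric).max())
```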


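Let's use the chain rule for (partial) derivatives and the rules of matrix differentiation. The figure below shows the last hidden layer of the network, used to backpropagate the regression (MSE) error: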

[Figure: the last hidden layer of the network, with backpropagation of the MSE error]

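E = err = (Y - T)**2 (averaged over the batch to compute the MSE)

Here W3 and B3 denote the last layer's weights and bias, i.e. W[3] and B[3] in the forward pass (W4 and B4 in the question's variable names).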

DY = ∂E/∂Y = 2 * (Y - T)

∂E/∂W3 = (∂E/∂Y).(∂Y/∂W3)
= DY . (∂/∂W3 (A3.W3 + B3)) = DY . A3.T

= A3.T . DY (averaged over all training samples in the batch X: sum, then divide by the batch size |X|)
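The batch average can be made concrete: A3.T . DY / |X| equals the mean of the per-sample outer products of each activation row with its DY entry. A minimal check, assuming random A3 and DY for a hypothetical batch of 5 samples:

```python
import numpy

rng = numpy.random.default_rng(5)
N = 5
A3 = rng.normal(size=(N, 6))     # hypothetical activations, one row per sample
DY = rng.normal(size=(N, 1))     # hypothetical per-sample dE/dY

batched = A3.T.dot(DY) / N       # one matrix product for the whole batch
per_sample = sum(numpy.outer(A3[n], DY[n]) for n in range(N)) / N  # loop over samples
print(numpy.abs(batched - per_sample).max())
```

The single matrix product is the usual choice because it does the per-sample sum inside one optimized call.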

∂E/∂B3 = (∂E/∂Y).(∂Y/∂B3)
= DY . (∂/∂B3 (A3.W3 + B3)) = DY . 1

= DY (averaged over all examples in the batch)

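∂E/∂Z3
= (∂E/∂Y).(∂Y/∂A3).(∂A3/∂Z3)

= DY . (∂/∂A3 (A3.W3 + B3)) * (1·𝟙{Z3 > 0} + 0·𝟙{Z3 ≤ 0})

= DY . W3.T * 𝟙{Z3 > 0}, where 𝟙{.} is the indicator function and * is elementwise multiplication. By the definition of the nonlinear ReLU activation, its derivative is 1 where Z3 > 0 and 0 otherwise.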
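The same results written as one compact derivation, with the batch average made explicit for N samples and a single output. This restatement uses the question's names W4, B4 (the answer's W3, B3) and summarizes the steps above rather than adding a new result:

```latex
\begin{aligned}
E &= \frac{1}{N}\sum_{n=1}^{N} (y_n - t_n)^2, \qquad D_Y = 2\,(Y - T) \\
\frac{\partial E}{\partial W_4} &= \frac{1}{N}\, A_3^{\top} D_Y, \qquad
\frac{\partial E}{\partial B_4} = \frac{1}{N}\sum_{n=1}^{N} (D_Y)_n \\
\frac{\partial E}{\partial Z_3} &= \frac{1}{N}\left(D_Y W_4^{\top}\right) \odot \mathbf{1}[Z_3 > 0]
\end{aligned}
```

The question's DZ3 = DY.dot(W4.T)*(Z3 > 0) leaves out the 1/N factor; it is applied later, when the next layer's weight and bias gradients are divided by len(X) or averaged over the batch.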
