关于使用 PyTorch 编程 cnn 的问题

如何解决关于使用 PyTorch 编程 cnn 的问题

我对 cnn 编程还很陌生,所以我有点迷茫。我正在尝试执行这部分代码,他们要求我实现一个完全连接的网络来对数字进行分类。它应该包含 1 个具有 20 个单元的隐藏层。我应该在隐藏层上使用 ReLU 激活函数。

class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()
        self.fc1 = ... 
        
        self.fc2 = nn.Sequential(
            nn.Linear(500,10),nn.Softmax(dim = 1)
            )
        
    def forward(self,x):
        x = x.view(x.size(0),-1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

点是要填充的部分,我想到了这一行:

self.fc1 = nn.Linear(20,500)

但我不知道它是否正确。有人可以帮我吗?而且我完全不明白 Softmax 的功能是做什么的……所以如果有人知道的话。 非常感谢!!

钯。这是加载数据的代码:

batch_size = 64
trainset = datasets.MNIST('./data',train=True,download=True,transform=transforms.ToTensor())
train_loader = DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=1)
testset = datasets.MNIST('./data',train=False,transform=transforms.ToTensor())
test_loader = DataLoader(testset,shuffle=False,num_workers=1)

解决方法

从模型给出的代码可以看出,隐藏层有500个单元。所以我假设你的意思是输入 20 个单位。有了这个假设,代码必须是:

self.fc1 = nn.Sequential(
    nn.Linear(20,500),nn.ReLU()
    )

进入问题的下一部分,鉴于您正在使用 MNIST 数据集并且您拥有 softmax 函数,我假设您正在尝试预测图像中存在的数量。 您的神经网络在每一层执行各种乘法和加法运算,最后,您在输出层得到 10 个数字。现在,您必须理解这 10 个数字才能决定图像中给出的 10 个数字中的哪一个。

一种方法是选择具有最大值的单位。例如,如果第 10 个单位在所有单位中具有最大值,则我们得出结论该数字为“9”。如果第 2 个单位具有最大值,则我们得出结论,该数字为“1”。

这很好,但更好的方法是将每个单位的值转换为图像中包含相应数字的概率,然后我们选择概率最高的数字。这具有一定的数学优势,有助于我们定义更好的损失函数。

Softmax 帮助我们将值转换为概率。在应用 softmax 时,所有值都在 (0,1) 范围内,并且它们总和为 1。

如果您对深度学习及其背后的数学感兴趣,我建议您查看 Andrew NG 的深度学习课程。

,

您没有提及数据的形状,因此我将假设 datasets.MNIST 返回的预期形状。

数据形状:torch.Size([64,1,28,28])

class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(1*28*28,20),nn.ReLU())
        
        self.fc2 = nn.Sequential(
            nn.Linear(500,10),nn.Softmax(dim = 1))
        
    def forward(self,x):
        x = x.view(x.size(0),-1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

nn.Linear 的第一个参数是输入特征的大小,第二个参数是单位数

对于 self.fc1,输入特征的大小是除批量大小之外的数据形状的乘积,即 1 * 28 * 28。根据您的帖子,第二个参数应该是 20(20 个单位)。

self.fc1 的输出(也是 self.fc2 的输入)的形状将是 (batch size,20)

对于self.fc2,输入特征的大小将为20,而单位数(也是位数)为10

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-
参考1 参考2 解决方案 # 点击安装源 协议选择 http:// 路径填写 mirrors.aliyun.com/centos/8.3.2011/BaseOS/x86_64/os URL类型 软件库URL 其他路径 # 版本 7 mirrors.aliyun.com/centos/7/os/x86
报错1 [root@slave1 data_mocker]# kafka-console-consumer.sh --bootstrap-server slave1:9092 --topic topic_db [2023-12-19 18:31:12,770] WARN [Consumer clie
错误1 # 重写数据 hive (edu)> insert overwrite table dwd_trade_cart_add_inc > select data.id, > data.user_id, > data.course_id, > date_format(
错误1 hive (edu)> insert into huanhuan values(1,'haoge'); Query ID = root_20240110071417_fe1517ad-3607-41f4-bdcf-d00b98ac443e Total jobs = 1
报错1:执行到如下就不执行了,没有显示Successfully registered new MBean. [root@slave1 bin]# /usr/local/software/flume-1.9.0/bin/flume-ng agent -n a1 -c /usr/local/softwa
虚拟及没有启动任何服务器查看jps会显示jps,如果没有显示任何东西 [root@slave2 ~]# jps 9647 Jps 解决方案 # 进入/tmp查看 [root@slave1 dfs]# cd /tmp [root@slave1 tmp]# ll 总用量 48 drwxr-xr-x. 2
报错1 hive> show databases; OK Failed with exception java.io.IOException:java.lang.RuntimeException: Error in configuring object Time taken: 0.474 se
报错1 [root@localhost ~]# vim -bash: vim: 未找到命令 安装vim yum -y install vim* # 查看是否安装成功 [root@hadoop01 hadoop]# rpm -qa |grep vim vim-X11-7.4.629-8.el7_9.x
修改hadoop配置 vi /usr/local/software/hadoop-2.9.2/etc/hadoop/yarn-site.xml # 添加如下 <configuration> <property> <name>yarn.nodemanager.res