微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在不改变卷积层权重的情况下更新分类层

如何解决如何在不改变卷积层权重的情况下更新分类层

我有一个带有许多卷积层的 CNN。对于这些卷积层中的每一个,我都附加了一个分类器,以检查中间层的输出。在为这些分类器中的每一个产生损失后,我想更新分类器的权重而不触及卷积层的权重。这段代码

for i in range(len(loss_per_layer)):
    loss_per_layer[i].backward(retain_graph=True)
    self.classifiers[i].weight.data -= self.learning_rate * self.alpha[i] * self.classifiers[i].weight.grad.data
    self.classifiers[i].bias.data -= self.learning_rate * self.alpha[i] * self.classifiers[i].bias.grad.data

如果分类器由单个 nn.Linear 层组成,则允许我这样做。但是,我的分类器的形状如下:

self.classifiers.append(nn.Sequential(
    nn.Linear(int(feature_map * input_shape[1] * input_shape[2]),100),nn.ReLU(True),nn.Dropout(),nn.Linear(100,self.num_classes),nn.Sigmoid(),))

如何在不影响网络其余部分的情况下更新 Sequential 块的权重?我最近从 keras 更改为 pytorch,因此不确定如何在这种情况下准确地使用 optimizer.step() 函数,但我怀疑可以使用它来完成。

请注意,我需要一个适用于任何形状的 Sequential 块的通用解决方案,因为它会在模型​​的未来迭代中发生变化。

非常感谢任何帮助。

解决方法

您可以按如下方式实现您的模型:

class Model(nn.Module):
       def __init__(self,conv_layers,classifier):
           super().__init__()
           self.conv_layers = conv_layers
           self.classifier = classifier

       def forward(self,x):
           x = self.conv_layers(x)
           return self.classifier(x)

在声明优化器时,只传递你想要更新的参数。

       model = Model(conv_layers,classifier)            
       optimizer = torch.optim.Adam(model.classifier.parameters(),lr=lr)
  

现在你什么时候做

       loss.backward()
       optimizer.step()
       model.zero_grad()

仅更新分类器参数。

编辑:在 OP 发表评论后,我将在下面添加更多通用用例。

更通用的场景

   class Model(nn.Module):
       def __init__(self,modules):
           super().__init__()
           # supposing you have multiple modules declared like below. 
           # You can also keep them as an array or dict too. 
           # For this see nn.ModuleList or nn.ModuleDict in pytorch 
           self.module0 = modules[0]
           self.module1 = modules[1]
           #..... and so on

       def forward(self,x):
           # implement forward 

   # model and optimizer declarations
   model = Model(modules)  
   # assuming we want to update module0 and module1          
   optimizer = torch.optim.Adam([
       model.module0.parameters(),model.module1.parameters()
   ],lr=lr)
   # you can also provide different learning rate for different modules. 
   # See [documentation][https://pytorch.org/docs/stable/optim.html]
   
   # when training  
   loss.backward()
   optimizer.step()
   model.zero_grad()
   # use model.zero_grad to remove gradient computed for all the modules.
   # optimizer.zero_grad only removes gradient for parameters that were passed to it. 
,

如果您使用的是内置或自定义 torch.optim.Optimizer,则无需手动执行参数更新。您可以简单地冻结不想更新的图层(通过停用其 requires_grad 标志)。在损失时调用 .backward(),然后 optimizer.step() 只会更新分类器。

根据您的 torch.nn.Module 模型架构,您可以这样做:

for param in model.feature_extractor.parameters():
    param.requires_grad = False

其中 model.feature_extractor 将是包含卷积层的模型的头部(特征提取器)。您可以通过这种方式在任何模块上循环,.parameters() 将循环此模块子参数的所有参数以停用 requires_grad

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?