如何解决如何使用 pytorch 闪电剥离预训练网络并为其添加一些层?
我正在尝试将迁移学习用于图像分割任务,我的计划是使用预训练模型的前几层(例如 VGG16)作为编码器,然后添加我自己的解码器。
所以,我可以加载模型并通过打印来查看结构:
model = torch.hub.load('pytorch/vision:v0.6.0','resnet18',pretrained=True)
print(model)
我是这样的:
ResNet(
(conv1): Conv2d(3,64,kernel_size=(7,7),stride=(2,2),padding=(3,3),bias=False)
(bn1): BatchNorm2d(64,eps=1e-05,momentum=0.1,affine=True,track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3,stride=2,padding=1,dilation=1,ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64,kernel_size=(3,stride=(1,1),padding=(1,bias=False)
(bn1): BatchNorm2d(64,track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64,bias=False)
(bn2): BatchNorm2d(64,track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64,track_running_stats=True)
)
)
.....
.....
.....
例如,我还可以使用 model.layer3
访问特定层。现在,我正在为某些事情而苦苦挣扎。
- 如何对模型进行切割,并从任何层的开始到结束(例如model.layer3)获取每个模块?
- 如何
freeze
仅Action.Submit
这个剥离的部分,并保持新添加的模块可用于培训?
解决方法
以下适用于 model
的任何子模块,但我将在这里用 model.layer3
回答您的问题:
-
一样直接调用它model.layer3
将为您提供与模型第 3 层关联的nn.Module
。您可以像使用model
>>> z = model.layer3(torch.rand(16,128,10,10)) >>> z.shape torch.Size([16,256,5,5])
-
要冻结模型:
-
您可以将层置于 eval 模式,以禁用 dropout 并使 BN 层在训练期间使用统计学习。这是通过
完成的model.layer3.eval()
-
您必须通过切换
requires_grad
标志来禁用该层的训练:model.layer3.requires_grad_(False)
,这将影响所有子参数。
-
您可以使用以下命令冻结图层:
pretrained_model.freeze()
,
对于 1):在您的 LightningModule
中初始化 ResNet 并将其切片直到您需要的部分。然后在此之后添加您自己的头部,并按照您需要的顺序定义 forward
。请参阅此示例,基于 transfer learning docs:
import torchvision.models as models
class ImagenetTransferLearning(LightningModule):
def __init__(self):
super().__init__()
# init a pretrained resnet
backbone_tmp = models.resnet50(pretrained=True)
num_filters = backbone_tmp.fc.in_features
layers = list(backbone_tmp.children())[:-1]
self.backbone = nn.Sequential(*layers)
# use the pretrained model to classify cifar-10 (10 image classes)
num_target_classes = 10
self.classifier = nn.Linear(num_filters,num_target_classes)
对于 2):将 BackboneFinetuning
callback 传递给您的 trainer
。这要求您的 LightningModule
具有 self.backbone
属性,其中包含您要冻结的模块,如上面的代码段所示。如果您需要不同的冻结-解冻行为,您也可以使用 BaseFinetuning
callback。
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BackboneFinetuning
multiplicative = lambda epoch: 1.5
backbone_finetuning = BackboneFinetuning(200,multiplicative)
trainer = Trainer(callbacks=[backbone_finetuning])
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。