如何解决无法从 PyTorch Lightning 的检查点加载模型的问题
我正在使用Pytorch Lightning中的U-Net。我能够成功训练模型,但是在训练后尝试从检查点加载模型时出现此错误:
完整的回溯信息(traceback):
Traceback (most recent call last):
File "src/train.py",line 269,in <module>
main(sys.argv[1:])
File "src/train.py",line 263,in main
model = Unet.load_from_checkpoint(checkpoint_callback.best_model_path)
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py",line 153,in load_from_checkpoint
model = cls._load_model_state(checkpoint,*args,strict=strict,**kwargs)
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/pytorch_lightning/core/saving.py",line 190,in _load_model_state
model = cls(*cls_args,**cls_kwargs)
File "src/train.py",line 162,in __init__
self.inc = double_conv(self.n_channels,64)
File "src/train.py",line 122,in double_conv
nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1),
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py",line 406,in __init__
super(Conv2d,self).__init__(
File "/home/africa_wikilimo/miniconda3/envs/xarray_test/lib/python3.8/site-packages/torch/nn/modules/conv.py",line 50,in __init__
if in_channels % groups != 0:
TypeError: unsupported operand type(s) for %: 'dict' and 'int'
我尝试浏览github问题和论坛,但无法弄清楚问题是什么。请帮忙。
这是我的模型代码和检查点加载步骤:
模型:
class Unet(pl.LightningModule):
    """U-Net built as a PyTorch Lightning module.

    The network has an input double convolution, four down-sampling stages,
    four up-sampling stages with skip connections, and a final 1x1
    convolution producing ``n_classes`` output channels. It is trained with
    a mean-squared-error loss and the Adam optimizer.

    NOTE: the constructor arguments are not recorded with
    ``self.save_hyperparameters(...)``. Because of that,
    ``Unet.load_from_checkpoint(path)`` has no stored values for
    ``n_channels`` / ``n_classes`` to rebuild the model with, which is the
    failure shown in the traceback above.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of channels of the output tensor.
    """

    def __init__(self, n_channels, n_classes=5):
        super(Unet, self).__init__()
        # self.hparams = hparams
        self.n_channels = n_channels
        self.n_classes = n_classes
        # Stored but never forwarded to `up` below, so every `up` stage uses
        # its default (transposed convolution) up-sampling.
        self.bilinear = True
        self.logger = WandbLogger(name="Adam", project="pytorchlightning")

        def double_conv(in_channels, out_channels):
            """Return two (3x3 conv -> batch norm -> ReLU) stages in sequence."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def down(in_channels, out_channels):
            """Return a 2x max-pool followed by a double convolution."""
            return nn.Sequential(
                nn.MaxPool2d(2), double_conv(in_channels, out_channels)
            )

        class up(nn.Module):
            """Up-sample ``x1``, pad it to ``x2``'s size, concatenate, convolve."""

            def __init__(self, in_channels, out_channels, bilinear=False):
                super().__init__()
                if bilinear:
                    self.up = nn.Upsample(
                        scale_factor=2, mode="bilinear", align_corners=True
                    )
                else:
                    self.up = nn.ConvTranspose2d(
                        in_channels // 2, in_channels // 2, kernel_size=2, stride=2
                    )
                self.conv = double_conv(in_channels, out_channels)

            def forward(self, x1, x2):
                x1 = self.up(x1)
                # [?,C,H,W]
                diffY = x2.size()[2] - x1.size()[2]
                diffX = x2.size()[3] - x1.size()[3]
                # Pad x1 so its height/width match the skip connection x2.
                x1 = F.pad(
                    x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]
                )
                x = torch.cat([x2, x1], dim=1)
                return self.conv(x)

        self.inc = double_conv(self.n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.out(x)

    def training_step(self, batch, batch_nb):
        """Compute the MSE loss for one ``(x, y)`` training batch."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        # wandb_logger.log_metrics({"loss":loss})
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        """Average the per-step losses of the epoch and log the result."""
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.logger.log_metrics({"train_loss": avg_train_loss})
        return {"average_loss": avg_train_loss}

    def test_step(self, batch, batch_nb):
        """Compute the MSE loss and prediction for one ``(x, y)`` test batch."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.MSE(y_hat, y)
        return {"test_loss": loss, "pred": y_hat}

    def test_end(self, outputs):
        """Average the per-step test losses."""
        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
        return {"avg_test_loss": avg_loss}

    def MSE(self, logits, labels):
        """Return the mean squared difference between ``logits`` and ``labels``."""
        return torch.mean((logits - labels) ** 2)

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.1, weight_decay=1e-8)
主要功能:
def main(expconfig):
    """Train a ``Unet`` for one epoch, then reload it from the best checkpoint.

    Args:
        expconfig: Sequence whose first element is passed to
            ``clima_Dataset`` to build the training dataset.
    """
    # Checkpointing: keep only the single best model, ranked by minimum "loss".
    ckpt_cb = ModelCheckpoint(
        filepath="/home/africa_wikilimo/data/model_checkpoint/",
        save_top_k=1,
        verbose=True,
        monitor="loss",
        mode="min",
        prefix="",
    )

    # Dataset
    print("Initializing climate Dataset....")
    train_set = clima_Dataset(expconfig[0])

    # DataLoader
    print("Initializing train_loader....")
    loader = DataLoader(train_set, batch_size=2, num_workers=4)

    # Model
    print("Initializing model...")
    net = Unet(n_channels=9, n_classes=5)

    # Trainer: the GPU run adds `gpus=1` and disables early stopping.
    print("Initializing Trainer....")
    trainer_options = dict(max_epochs=1, checkpoint_callback=ckpt_cb)
    if torch.cuda.is_available():
        net.cuda()
        trainer_options.update(gpus=1, early_stop_callback=None)
    trainer = pl.Trainer(**trainer_options)

    # NOTE(review): the Lightning docs spell this keyword `train_dataloader`;
    # confirm the spelling against the installed version.
    trainer.fit(net, train_DataLoader=loader)

    best_path = ckpt_cb.best_model_path
    print(best_path)
    # Rebuild the model from the best checkpoint found during training.
    net = Unet.load_from_checkpoint(best_path)
解决方法
原因
之所以会发生这种情况,是因为您的模型无法从检查点加载超参数(n_channels,n_classes = 5),因为您没有显式地保存它们。
修复
您可以在 Unet 类的 `__init__` 方法中调用 `self.save_hyperparameters('n_channels','n_classes')` 来解决此问题。
有关使用此方法的更多详细信息,请参见PyTorch Lightning hyperparams-docs。使用save_hyperparameters可使选定的参数与检查点一起保存在 hparams.yaml 中。
感谢 PyTorch Lightning 核心贡献者团队的 @Adrian Wälchli(awaelchli):当我遇到同样的问题时,是他提出了此修复方法。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。