如何解决中断后如何恢复训练 pl.Trainer?
我有 Model 和 Trainer pytorch-lightning 对象,它们初始化如下:
checkpoint_callback = ModelCheckpoint(
filepath=os.path.join('experiments',experiment_name,'{epoch}-{avg_valid_iou:.4f}'),save_top_k=1,verbose=True,monitor='avg_valid_iou',mode='max',prefix=''
)
model = nn.DataParallel (FaultNetPL(batch_size=5)).cuda()
trainer = Trainer( checkpoint_callback=checkpoint_callback,max_epochs=500,gpus=1,logger=logger)
然后我开始使用:
trainer.fit(model)
但是训练被中断了,现在我想使用第 N 次迭代的检查点来恢复它 所以我尝试将模型和训练器初始化为:
model = FaultNetPL.load_from_checkpoint('experiments/VNET/epoch=77-avg_valid_iou=0.7604.ckpt',batch_size=5)
trainer = Trainer(resume_from_checkpoint = 'epoch=77-avg_valid_iou=0.7604.ckpt',checkpoint_callback=checkpoint_callback,logger=logger)
但是一次又一次地从头开始训练(从第 0 个纪元开始,错误巨大)。我错过了什么?
解决方法
您不仅应该保存模型状态,还应该保存优化器状态和起始时期值。例如:
state({
'epoch': epoch + 1,'state_dict': model.module.state_dict(),'optimizer': optimizer.state_dict(),})
保存检查点后,您可以通过以下命令继续训练:
checkpoint = torch.load(state_file)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_val = checkpoint['epoch']
for epoch in range(start_val,max_val):
...
...
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。