如何解决PyTorch Lightning:在检查点文件中包含一些 Tensor 对象
由于 Pytorch Lightning 为模型检查点提供自动保存,我使用它来保存 top-k 最佳模型。特别是在培训师设置中,
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',dirpath='checkpoints/',filename='{epoch:02d}-{val_acc:.2f}',save_top_k=5,mode='max',)
这运行良好,但它没有保存模型对象的某些属性。我的模型在每个训练时期结束时存储一些张量,使得
class SampleNet(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Linear(100,1)
self.loss = torch.nn.CrossEntropy()
self.some_data = None # Initialize as None
def training_step(self,batch):
x,t = batch
out = self.layer(x)
loss = self.loss(out,t)
results = {'loss': loss}
return results
def training_epoch_end(self,outputs):
self.some_data = some_tensor_object
这是一个简化的示例,但我希望上面 checkpoint_callback
制作的检查点文件记住属性 self.some_data
,但是当我从检查点加载模型时,它总是重置为 None
。我在培训期间确认它已成功更新。
我尝试不在 init
中将其初始化为 None ,但是加载模型时该属性将消失。
将属性保存为不同的 pt
文件是我想要避免的事情,因为它与模型配置相关联,因此我稍后需要手动将该文件与相应的检查点文件进行匹配。
是否可以在检查点文件中包含这样的张量属性?
解决方法
这似乎不可能直接,因为提取最有可能使用 nn.Module.state_dict()
的参数。
该方法仅提取实际被视为参数的张量的值。因此,在这种情况下,一种解决方法是将您的数据保存为参数(请参阅 docs):
self.some_data = torch.nn.parameter.Parameter(your_data)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。