如何解决使用 torch.nn.DataParallel() 时如何访问类对象?
我想使用带有多个 GPU 的 PyTorch 训练我的模型。我包括以下行:
model = torch.nn.DataParallel(model,device_ids=opt.gpu_ids)
然后,我尝试访问在我的模型定义中定义的优化器:
G_opt = model.module.optimizer_G
AttributeError: 'DataParallel' 对象没有属性 optimizer_G
我认为这与我的模型定义中优化器的定义有关。当我在没有 torch.nn.DataParallel
的情况下使用单个 GPU 时,它可以工作。但它不适用于多 GPU,即使我使用 module
调用并且我找不到解决方案。
这是模型定义:
class MyModel(torch.nn.Module):
...
self.optimizer_G = torch.optim.Adam(params,lr=opt.lr,betas=(opt.beta1,0.999))
如果您想查看完整代码,我在 GitHub 中使用了 Pix2PixHD 实现。
谢谢, 最好的。
编辑:我使用 model.module.module.optimizer_G
解决了这个问题。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。