如何解决使用 PyTorch 模型执行推理时子进程挂起
我有一个 PyTorch 模型(类 Net
)及其保存的权重/状态字典 (net.pth
),我想在多处理环境中执行推理。
我注意到我不能简单地创建一个模型实例,加载权重,然后与子进程共享模型(尽管我认为这是可能的,因为写时复制)。发生的情况是子进程挂在 y = model(x)
上,最后整个程序挂起(由于父进程的 waitpid
)。
以下是一个可重现的最小示例:
def handler():
with torch.no_grad():
x = torch.rand(1,3,32,32)
y = model(x)
return y
model = Net()
model.load_state_dict(torch.load("./net.pth"))
pid = os.fork()
if pid == 0:
# this doesn't get printed as handler() hangs for the child process
print('child:',handler())
else:
# everything is fine here
print('parent:',handler())
os.waitpid(pid,0)
如果模型加载是为父和子独立完成的,即没有共享,那么一切都按预期进行。我也试过在模型的张量上调用 share_memory_
,但无济于事。
我在这里做错了什么吗?
解决方法
似乎共享状态字典并在每个进程中执行加载操作解决了问题:
LOADED = False
def handler():
global LOADED
if not LOADED:
# each process loads state independently
model.load_state_dict(state)
LOADED = True
with torch.no_grad():
x = torch.rand(1,3,32,32)
y = model(x)
return y
model = Net()
# share the state rather than loading the state dict in parent
# model.load_state_dict(torch.load("./net.pth"))
state = torch.load("./net.pth")
pid = os.fork()
if pid == 0:
print('child:',handler())
else:
print('parent:',handler())
os.waitpid(pid,0)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。