微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 PyTorch 模型执行推理时子进程挂起

如何解决使用 PyTorch 模型执行推理时子进程挂起

我有一个 PyTorch 模型(类 Net)及其保存的权重/状态字典 (net.pth),我想在多处理环境中执行推理。

我注意到我不能简单地创建一个模型实例,加载权重,然后与子进程共享模型(尽管我认为这是可能的,因为写时复制)。发生的情况是子进程挂在 y = model(x) 上,最后整个程序挂起(由于父进程的 waitpid)。

以下是一个可重现的最小示例:

def handler():
    with torch.no_grad():
        x = torch.rand(1,3,32,32)
        y = model(x)

    return y


model = Net()
model.load_state_dict(torch.load("./net.pth"))

pid = os.fork()

if pid == 0:
    # this doesn't get printed as handler() hangs for the child process
    print('child:',handler())
else:
    # everything is fine here
    print('parent:',handler())
    os.waitpid(pid,0)

如果模型加载是为父和子独立完成的,即没有共享,那么一切都按预期进行。我也试过在模型的张量上调用 share_memory_,但无济于事。

在这里做错了什么吗?

解决方法

似乎共享状态字典并在每个进程中执行加载操作解决了问题:

LOADED = False 

def handler():
    global LOADED
    if not LOADED:
        # each process loads state independently
        model.load_state_dict(state)
        LOADED = True

    with torch.no_grad():
        x = torch.rand(1,3,32,32)
        y = model(x)

    return y


model = Net()

# share the state rather than loading the state dict in parent
# model.load_state_dict(torch.load("./net.pth"))
state = torch.load("./net.pth")

pid = os.fork()

if pid == 0:
    print('child:',handler())
else:
    print('parent:',handler())
    os.waitpid(pid,0)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。