如何解决迭代torch._C.Graph中的节点
我如何才能准确地找到 PyTorch 模型图中存在的节点,以及它们的输入是什么? 我尝试使用
获取torch._C.Graph
对象
scripted=torch.jit.script(MyModel().eval())
frozen_module = torch.jit.freeze(scripted)
print(frozen_module.inlined_graph)
给出了以下输出
graph(%self : __torch__.___torch_mangle_2.MyModel,%x1.1 : Tensor,%x2.1 : Tensor,%x3.1 : Tensor):
%4 : Float(52229:1,4:52229,requires_grad=0,device=cpu) = prim::Constant[value=<Tensor>]()
%5 : Float(10:1,5:10,device=cpu) = prim::Constant[value=<Tensor>]()
%6 : int[] = prim::Constant[value=[0,0]]()
%7 : int[] = prim::Constant[value=[2,2]]()
%8 : int[] = prim::Constant[value=[1,1]]()
%9 : int = prim::Constant[value=2]()
%10 : bool = prim::Constant[value=0]()
%11 : int = prim::Constant[value=1]() # test.py:39:34
%12 : int = prim::Constant[value=0]() # test.py:39:29
%13 : int = prim::Constant[value=-1]() # test.py:39:33
%self.classifier.bias : Float(4:1,device=cpu) = prim::Constant[value=0.001 * 2.8424 1.0601 -1.3229 4.2920 [ cpuFloatType{4} ]]()
%self.features3.0.bias : Float(5:1,device=cpu) = prim::Constant[value= 0.0111 -0.0702 0.1396 0.1691 0.1335 [ cpuFloatType{5} ]]()
%self.features2.0.bias : Float(3:1,device=cpu) = prim::Constant[value= 0.3314 0.0165 0.2588 [ cpuFloatType{3} ]]()
%self.features2.0.weight : Float(3:9,1:9,3:3,3:1,device=cpu) = prim::Constant[value=<Tensor>]()
%self.features1.0.bias : Float(3:1,device=cpu) = prim::Constant[value=0.01 * 2.5380 -31.8947 -15.3462 [ cpuFloatType{3} ]]()
%self.features1.0.weight : Float(3:9,device=cpu) = prim::Constant[value=<Tensor>]()
%input.4 : Tensor = aten::conv2d(%x1.1,%self.features1.0.weight,%self.features1.0.bias,%8,%11)
%input.6 : Tensor = aten::max_pool2d(%input.4,%7,%6,%10)
%x1.3 : Tensor = aten::relu(%input.6)
%input.7 : Tensor = aten::conv2d(%x2.1,%self.features2.0.weight,%self.features2.0.bias,%11)
%input.8 : Tensor = aten::max_pool2d(%input.7,%10)
%x2.3 : Tensor = aten::relu(%input.8)
%26 : int = aten::dim(%x3.1)
%27 : bool = aten::eq(%26,%9)
%input.3 : Tensor = prim::If(%27)
block0():
%ret.2 : Tensor = aten::addmm(%self.features3.0.bias,%x3.1,%5,%11,%11)
-> (%ret.2)
block1():
%output.2 : Tensor = aten::matmul(%x3.1,%5)
%output.4 : Tensor = aten::add_(%output.2,%self.features3.0.bias,%11)
-> (%output.4)
%x3.3 : Tensor = aten::relu(%input.3)
%33 : int = aten::size(%x1.3,%12)
%34 : int[] = prim::ListConstruct(%33,%13)
%x1.6 : Tensor = aten::view(%x1.3,%34)
%36 : int = aten::size(%x2.3,%12)
%37 : int[] = prim::ListConstruct(%36,%13)
%x2.6 : Tensor = aten::view(%x2.3,%37)
%39 : int = aten::size(%x3.3,%12)
%40 : int[] = prim::ListConstruct(%39,%13)
%x3.6 : Tensor = aten::view(%x3.3,%40)
%42 : Tensor[] = prim::ListConstruct(%x1.6,%x2.6,%x3.6)
%x.1 : Tensor = aten::cat(%42,%11)
%44 : int = aten::dim(%x.1)
%45 : bool = aten::eq(%44,%9)
%x.3 : Tensor = prim::If(%45)
block0():
%ret.1 : Tensor = aten::addmm(%self.classifier.bias,%x.1,%4,%11)
-> (%ret.1)
block1():
%output.1 : Tensor = aten::matmul(%x.1,%4)
%output.3 : Tensor = aten::add_(%output.1,%self.classifier.bias,%11)
-> (%output.3)
return (%x.3)
但是我无法迭代或找到其中存在的节点或它具有的输入究竟是什么。请建议是否有任何其他方式来执行上述操作。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。