微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

迭代torch._C.Graph中的节点

如何解决迭代torch._C.Graph中的节点

我如何才能准确地找到 PyTorch 模型图中存在的节点,以及它们的输入是什么? 我尝试使用

获取 torch._C.Graph 对象

scripted=torch.jit.script(MyModel().eval())
frozen_module = torch.jit.freeze(scripted)
print(frozen_module.inlined_graph)

给出了以下输出

graph(%self : __torch__.___torch_mangle_2.MyModel,%x1.1 : Tensor,%x2.1 : Tensor,%x3.1 : Tensor):
  %4 : Float(52229:1,4:52229,requires_grad=0,device=cpu) = prim::Constant[value=<Tensor>]()
  %5 : Float(10:1,5:10,device=cpu) = prim::Constant[value=<Tensor>]()
  %6 : int[] = prim::Constant[value=[0,0]]()
  %7 : int[] = prim::Constant[value=[2,2]]()
  %8 : int[] = prim::Constant[value=[1,1]]()
  %9 : int = prim::Constant[value=2]() 
  %10 : bool = prim::Constant[value=0]()
  %11 : int = prim::Constant[value=1]() # test.py:39:34
  %12 : int = prim::Constant[value=0]() # test.py:39:29
  %13 : int = prim::Constant[value=-1]() # test.py:39:33
  %self.classifier.bias : Float(4:1,device=cpu) = prim::Constant[value=0.001 *  2.8424  1.0601 -1.3229  4.2920 [ cpuFloatType{4} ]]()
  %self.features3.0.bias : Float(5:1,device=cpu) = prim::Constant[value= 0.0111 -0.0702  0.1396  0.1691  0.1335 [ cpuFloatType{5} ]]()
  %self.features2.0.bias : Float(3:1,device=cpu) = prim::Constant[value= 0.3314  0.0165  0.2588 [ cpuFloatType{3} ]]()
  %self.features2.0.weight : Float(3:9,1:9,3:3,3:1,device=cpu) = prim::Constant[value=<Tensor>]()
  %self.features1.0.bias : Float(3:1,device=cpu) = prim::Constant[value=0.01 *  2.5380 -31.8947 -15.3462 [ cpuFloatType{3} ]]()
  %self.features1.0.weight : Float(3:9,device=cpu) = prim::Constant[value=<Tensor>]()
  %input.4 : Tensor = aten::conv2d(%x1.1,%self.features1.0.weight,%self.features1.0.bias,%8,%11) 
  %input.6 : Tensor = aten::max_pool2d(%input.4,%7,%6,%10) 
  %x1.3 : Tensor = aten::relu(%input.6) 
  %input.7 : Tensor = aten::conv2d(%x2.1,%self.features2.0.weight,%self.features2.0.bias,%11) 
  %input.8 : Tensor = aten::max_pool2d(%input.7,%10) 
  %x2.3 : Tensor = aten::relu(%input.8) 
  %26 : int = aten::dim(%x3.1) 
  %27 : bool = aten::eq(%26,%9) 
  %input.3 : Tensor = prim::If(%27) 
    block0():
      %ret.2 : Tensor = aten::addmm(%self.features3.0.bias,%x3.1,%5,%11,%11) 
      -> (%ret.2)
    block1():
      %output.2 : Tensor = aten::matmul(%x3.1,%5) 
      %output.4 : Tensor = aten::add_(%output.2,%self.features3.0.bias,%11) 
      -> (%output.4)
  %x3.3 : Tensor = aten::relu(%input.3) 
  %33 : int = aten::size(%x1.3,%12) 
  %34 : int[] = prim::ListConstruct(%33,%13)
  %x1.6 : Tensor = aten::view(%x1.3,%34) 
  %36 : int = aten::size(%x2.3,%12) 
  %37 : int[] = prim::ListConstruct(%36,%13)
  %x2.6 : Tensor = aten::view(%x2.3,%37) 
  %39 : int = aten::size(%x3.3,%12) 
  %40 : int[] = prim::ListConstruct(%39,%13)
  %x3.6 : Tensor = aten::view(%x3.3,%40)
  %42 : Tensor[] = prim::ListConstruct(%x1.6,%x2.6,%x3.6)
  %x.1 : Tensor = aten::cat(%42,%11) 
  %44 : int = aten::dim(%x.1)
  %45 : bool = aten::eq(%44,%9) 
  %x.3 : Tensor = prim::If(%45) 
    block0():
      %ret.1 : Tensor = aten::addmm(%self.classifier.bias,%x.1,%4,%11) 
      -> (%ret.1)
    block1():
      %output.1 : Tensor = aten::matmul(%x.1,%4) 
      %output.3 : Tensor = aten::add_(%output.1,%self.classifier.bias,%11) 
      -> (%output.3)
  return (%x.3)

但是我无法迭代或找到其中存在的节点或它具有的输入究竟是什么。请建议是否有任何其他方式来执行上述操作。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。