微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何知道Pytorch模型的输入/输出层名称和大小?

如何解决如何知道Pytorch模型的输入/输出层名称和大小?

我使用Detectron2's COCO对象检测基线预训练的模型R50-FPN具有Pytorch model.pth。 我正在尝试转换.pth model to onnx

我的代码如下。

import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
from torchvision import models

model = torch.load('output_object_detection/model_final.pth')
x = torch.randn(1,3,1080,1920,requires_grad=True)#0,in_cha,in_h,in_w
torch_out = torch_model(x)
print(model)
torch.onnx.export(torch_model,# model being run
                  x,# model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",# where to save the model (can be a file or file-like object)
                  export_params=True,# store the trained parameter weights inside the model file
                  opset_version=10,# the ONNX version to export the model to
                  do_constant_folding=True,# whether to execute constant folding for optimization
                  input_names = ['input'],# the model's input names
                  output_names = ['cls_score','bBox_pred'],# the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},# variable lenght axes
                                'output' : {0 : 'batch_size'}})

转换ONNX模型是否正确? 如果是正确的方法,如何知道输入名称输出名称

使用netron查看输入和输出,但该图未显示输入/输出层。

解决方法

您可以尝试一次吗?
让我知道

import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
from torchvision import models    
model = torch.load('model_final.pth')
model.eval()
print('Finished loading model!')
print(model)
device = torch.device("cpu" if args.cpu else "cuda")
model = model.to(device)

# ------------------------ export -----------------------------
output_onnx = 'super_resolution.onnx'
print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
input_names = ["input0"]
output_names = ["output0","output1"]
inputs = torch.randn(1,3,1080,1920).to(device)

torch_out = torch.onnx._export(model,inputs,output_onnx,export_params=True,verbose=False,input_names=input_names,output_names=output_names)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?