微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

从 PyTorch 导出的 ONNX InferenceSession ONNX 模型失败

如何解决从 PyTorch 导出的 ONNX InferenceSession ONNX 模型失败

我正在尝试将自定义 PyTorch 模型导出到 ONNX 以执行推理但没有成功...这里的棘手之处在于我正在尝试使用 基于脚本的导出器,如图所示示例 here 以便从我的模型中调用函数

我可以毫无怨言地导出模型,但是在尝试启动 InferenceSession 时,我收到以下错误

Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ner.onnx Failed:Type Error: Type parameter (T) bound to different types (tensor(int64) and tensor(float) in node (Concat_1260).

我试图找出该问题的根本原因,它似乎是通过在以下函数中使用 torch.matmul() 生成的(非常讨厌,因为我试图仅使用 pytorch 运算符):

@torch.jit.script
def valid_sequence_output(sequence_output,valid_mask):
    X = torch.where(valid_mask.unsqueeze(-1) == 1,sequence_output,torch.zeros_like(sequence_output))
    bs,max_len,_ = X.shape

    tu = torch.unique(torch.nonzero(X)[:,:2],dim=0)
    batch_axis = tu[:,0]
    rows_axis = tu[:,1]

    a = torch.arange(bs).repeat(batch_axis.shape).reshape(batch_axis.shape[0],-1)
    a = torch.transpose(a,1)

    T = torch.cumsum(torch.where(batch_axis == a,torch.ones_like(a),torch.zeros_like(a)),dim=1) - 1
    cols_axis = T[batch_axis,torch.arange(batch_axis.shape[0])]

    A = torch.zeros((bs,max_len))
    A[(batch_axis,cols_axis,rows_axis)] = 1.0

    valid_output = torch.matmul(A,X)
    valid_attention_mask = torch.where(valid_output[:,:,0] != 0,torch.ones_like(valid_mask),torch.zeros_like(valid_mask))
    return valid_output,valid_attention_mask

似乎不支持 torch.matmul(根据文档),所以我尝试了一系列解决方法(例如 A.matmul(X)torch.baddbmm),但我仍然遇到同样的问题。 ..

有关如何解决此行为的任何建议都很棒:D 感谢您的帮助!

解决方法

这表明存在模型转换问题。请针对 Torch 导出器功能提出问题。类型 (T) 必须绑定到相同的类型才能使模型有效,而 ORT 基本上是在抱怨这一点。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?