微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何使用 RoBERTa ONNX 量化模型执行批量推理?

如何解决如何使用 RoBERTa ONNX 量化模型执行批量推理?

我已将 RoBERTa PyTorch 模型转换为 ONNX 模型并对其进行量化。我能够从 ONNX 模型中获得单个输入数据点(每个句子)的分数。我想了解如何通过将多个输入传递给会话来使用 ONNX 运行时推理会话进行批量预测。下面是示例场景。

模型:roberta-quant.onnx,它是 RoBERTa PyTorch 模型的 ONNX 量化版本

用于将 RoBERTa 转换为 ONNX 的代码

torch.onnx.export(model,args=tuple(inputs.values()),# model input 
                      f=export_model_path,# where to save the model 
                      opset_version=11,# the ONNX version to export the model to
                      do_constant_folding=True,# whether to execute constant folding for optimization
                      input_names=['input_ids',# the model's input names
                                   'attention_mask'],output_names=['output_0'],# the model's output names
                      dynamic_axes={'input_ids': symbolic_names,# variable length axes
                                    'attention_mask' : symbolic_names,'output_0' : {0: 'batch_size'}})

ONNXRuntime 推理会话的输入示例:

{
     'input_ids': array([[    0,510,35,21071,.....,1,1]]),'attention_mask': array([[1,.......,0]])
}

使用 ONNXRuntime 推理会话为 400 个数据样本(句子)运行 ONNX 模型:

session = onnxruntime.InferenceSession("roberta_quantized.onnx",providers=['cpuExecutionProvider'])
for i in range(400):
   ort_inputs = {
    'input_ids':  input_ids[i].cpu().reshape(1,max_seq_length).numpy(),# max_seq_length=128 here
    'input_mask': attention_masks[i].cpu().reshape(1,max_seq_length).numpy()
   }

   ort_outputs = session.run(None,ort_inputs)

在上面的代码中,我依次循环遍历 400 个句子以获得分数“ort_outputs”。请帮助我理解如何使用 ONNX 模型执行批处理在这里我可以发送 inputs_idsattention_masks 用于多个句子并获得所有句子的分数ort_outputs

提前致谢!

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?