如何解决如何理解tensorflow.js对象检测张量输出?
我的动机是构建一个自定义的异物检测Web应用程序。我从model zoo下载了一个tf2
预训练的SSD Resnet1010
模型。我的想法是,如果该实现可行,我将使用自己的数据训练模型。我运行$saved_model_cli show --dir saved_model --tag_set serve --signature_def serving_default
来找出输入和输出节点。
The given SavedModel SignatureDef contains the following input(s):
inputs['input_tensor'] tensor_info:
dtype: DT_UINT8
shape: (1,-1,3)
name: serving_default_input_tensor:0
The given SavedModel SignatureDef contains the following output(s):
outputs['detection_anchor_indices'] tensor_info:
dtype: DT_FLOAT
shape: (1,100)
name: StatefulPartitionedCall:0
outputs['detection_boxes'] tensor_info:
dtype: DT_FLOAT
shape: (1,100,4)
name: StatefulPartitionedCall:1
outputs['detection_classes'] tensor_info:
dtype: DT_FLOAT
shape: (1,100)
name: StatefulPartitionedCall:2
outputs['detection_multiclass_scores'] tensor_info:
dtype: DT_FLOAT
shape: (1,91)
name: StatefulPartitionedCall:3
outputs['detection_scores'] tensor_info:
dtype: DT_FLOAT
shape: (1,100)
name: StatefulPartitionedCall:4
outputs['num_detections'] tensor_info:
dtype: DT_FLOAT
shape: (1)
name: StatefulPartitionedCall:5
outputs['raw_detection_boxes'] tensor_info:
dtype: DT_FLOAT
shape: (1,51150,4)
name: StatefulPartitionedCall:6
outputs['raw_detection_scores'] tensor_info:
dtype: DT_FLOAT
shape: (1,91)
name: StatefulPartitionedCall:7
Method name is: tensorflow/serving/predict
然后我通过运行将模型转换为tensorflowjs模型
tensorflowjs_converter --input_format=tf_saved_model --output_node_names='detection_anchor_indices,detection_boxes,detection_classes,detection_multiclass_scores,detection_scores,num_detections,raw_detection_boxes,raw_detection_scores' --saved_model_tags=serve --output_format=tfjs_graph_model saved_model js_model
这是我的JavaScript代码(在vue方法中)
loadTfModel: async function(){
try {
this.model = await tf.loadGraphModel(this.MODEL_URL);
} catch(error) {
console.log(error);
}
},predictImg: async function() {
const imgData = document.getElementById('img');
let tf_img = tf.browser.fromPixels(imgData);
tf_img = tf_img.expandDims(0);
const predictions = await this.model.executeAsync(tf_img);
const data = []
for (let i = 0; i < predictions.length; i++){
data.push(predictions[i].dataSync());
}
console.log(data);
}
我的问题是数组中的这八个项目是否对应于八个已定义的输出节点?如何理解这些数据?以及如何将其转换为人类可读的格式(例如python)?
更新1:
我已经尝试过answer,并编辑了我的预测方法:
predictImg: async function() {
const imgData = document.getElementById('img');
let tf_img = tf.browser.fromPixels(imgData);
tf_img = tf_img.expandDims(0);
const predictions = await this.model.executeAsync(tf_img,['detection_classes']).then(predictions => {
const data = predictions.dataSync()
console.log('Predictions: ',data);
})
}
我最终得到"Error: The output 'detection_classes' is not found in the graph"
。我将不胜感激。
解决方法
在 {
"dialogAction": {
"type": "Close","fulfillmentState": "Fulfilled","message": {
"contentType": "PlainText","content": "Thanks,your pizza has been ordered."
},"responseCard": {
"version": integer-value,"contentType": "application/vnd.amazonaws.card.generic","genericAttachments": [
{
"title":"card-title","subTitle":"card-sub-title","imageUrl":"URL of the image to be shown","attachmentLinkUrl":"URL of the attachment to be associated with the card","buttons":[
{
"text":"button-text","value":"Value sent to server on button click"
}
]
}
]
}
}
}
中指定的输出节点可能存在错误。另外,这里this.model.executeAsync(tf_img,['detection_classes'])
不需要使用await
。使用await this.model.executeAsync(tf_img,['detection_classes'])
或使用await
。
获取then
的另一种方法是索引输出数组:
detection_classes
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。