微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用 Pycoral 库和 Google Coral USB 加速器进行二进制图像分类

如何解决使用 Pycoral 库和 Google Coral USB 加速器进行二进制图像分类

我有一个 Keras 模型,它使用 sigmoid 函数进行二元分类。我按照 Coral USB 的要求将我的模型编译为 .tflite 格式以运行推理。但是,我注意到脚本 classify_image.py 执行多类分类。因此,当我尝试对图像进行分类时,我对任何图像都得到了 100% 的预测。例如,我的模型将红外图像分类为发烧状态。即使我传球图像,它也会为发烧等级提供 100% 阳性。

因此,我再次使用完全自定义的模型对植物使用 layer softmax 测试了多类模型,这一次它起作用了。它为植物 A、植物 B 和植物 C 提供了合理的 85% 准确率。

因此,我想知道我需要做哪些更改,才能使用二进制分类自定义模型与 Pycoral 配合使用。

这是我用于分类代码

import argparse
import time

from PIL import Image
from pycoral.adapters import classify
from pycoral.adapters import common
from pycoral.utils.dataset import read_label_file
from pycoral.utils.edgetpu import make_interpreter
import cv2 as cv
import numpy as np


def main():
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpformatter)
  parser.add_argument('-m','--model',required=True,help='File path of .tflite file.')
  parser.add_argument('-i','--input',help='Image to be classified.')
  parser.add_argument('-l','--labels',help='File path of labels file.')
  parser.add_argument('-k','--top_k',type=int,default=2,help='Max number of classification results')
  parser.add_argument('-t','--threshold',type=float,default=0.0,help='Classification score threshold')
  parser.add_argument('-c','--count',default=5,help='Number of times to run inference')
  args = parser.parse_args()

  labels = read_label_file(args.labels) if args.labels else {}

  interpreter = make_interpreter(*args.model.split('@'))
  interpreter.allocate_tensors()
  print(interpreter)

  size = common.input_size(interpreter)
  image = cv.imread(args.input)
  image = cv.normalize(image,image,255,cv.norM_MINMAX)
  common.set_input(interpreter,image)

  print('----INFERENCE TIME----')
  print('Note: The first inference on Edge TPU is slow because it includes','loading the model into Edge TPU memory.')
  for _ in range(args.count):
    start = time.perf_counter()
    interpreter.invoke()
    inference_time = time.perf_counter() - start
    classes = classify.get_classes(interpreter,args.top_k,args.threshold)
    print('%.1fms' % (inference_time * 1000))

  print('-------RESULTS--------')
  for c in classes:
    print('%s: %.5f' % (labels.get(c.id,c.id),c.score))


if __name__ == '__main__':
  main()

我的 labels.txt 有两个标签,只有发烧(正类)和健康(负类)。用于二元模型分类的阈值为 0.50,据我所知,模型层与 Coral USB Accelerator 设备完全兼容。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。