如何实时检测对象并自动跟踪它,而不是用户必须在要跟踪的对象周围绘制边界框?

如何解决如何实时检测对象并自动跟踪它,而不是用户必须在要跟踪的对象周围绘制边界框?

我有以下代码用户可以在其中按 p 暂停视频,在要跟踪的对象周围绘制一个边界框,然后按 Enter(回车)以在视频源中跟踪该对象:

import cv2
import sys

major_ver,minor_ver,subminor_ver = cv2.__version__.split('.')

if __name__ == '__main__' :

    # Set up tracker.
    tracker_types = ['BOOSTING','MIL','kcf','TLD','MEDIANFLOW','GOTURN','MOSSE','CSRT']
    tracker_type = tracker_types[1]

    if int(minor_ver) < 3:
        tracker = cv2.Tracker_create(tracker_type)
    else:
        if tracker_type == 'BOOSTING':
            tracker = cv2.TrackerBoosting_create()
        if tracker_type == 'MIL':
            tracker = cv2.TrackerMIL_create()
        if tracker_type == 'kcf':
            tracker = cv2.Trackerkcf_create()
        if tracker_type == 'TLD':
            tracker = cv2.TrackerTLD_create()
        if tracker_type == 'MEDIANFLOW':
            tracker = cv2.TrackerMedianFlow_create()
        if tracker_type == 'GOTURN':
            tracker = cv2.TrackerGOTURN_create()
        if tracker_type == 'MOSSE':
            tracker = cv2.TrackerMOSSE_create()
        if tracker_type == "CSRT":
            tracker = cv2.TrackerCSRT_create()

    # Read video
    video = cv2.VideoCapture(0) # 0 means webcam. Otherwise if you want to use a video file,replace 0 with "video_file.MOV")

    # Exit if video not opened.
    if not video.isOpened():
        print ("Could not open video")
        sys.exit()

    while True:

        # Read first frame.
        ok,frame = video.read()
        if not ok:
            print ('Cannot read video file')
            sys.exit()
        
        # Retrieve an image and display it.
        if((0xFF & cv2.waitKey(10))==ord('p')): # Press key `p` to pause the video to start tracking
            break
        cv2.namedWindow("Image",cv2.WINDOW_norMAL)
        cv2.imshow("Image",frame)
    cv2.destroyWindow("Image");

    # select the bounding Box
    bBox = (287,23,86,320)

    # Uncomment the line below to select a different bounding Box
    bBox = cv2.selectROI(frame,False)

    # Initialize tracker with first frame and bounding Box
    ok = tracker.init(frame,bBox)

    while True:
        # Read a new frame
        ok,frame = video.read()
        if not ok:
            break
        
        # Start timer
        timer = cv2.getTickCount()

        # Update tracker
        ok,bBox = tracker.update(frame)

        # Calculate Frames per second (FPS)
        fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer);

        # Draw bounding Box
        if ok:
            # Tracking success
            p1 = (int(bBox[0]),int(bBox[1]))
            p2 = (int(bBox[0] + bBox[2]),int(bBox[1] + bBox[3]))
            cv2.rectangle(frame,p1,p2,(255,0),2,1)
        else :
            # Tracking failure
            cv2.putText(frame,"Tracking failure detected",(100,80),cv2.FONT_HERShey_SIMPLEX,0.75,(0,255),2)

        # display tracker type on frame
        cv2.putText(frame,tracker_type + " Tracker",20),(50,170,50),2);
    
        # display FPS on frame
        cv2.putText(frame,"FPS : " + str(int(fps)),2);

        # display result
        cv2.imshow("Tracking",frame)

        # Exit if ESC pressed
        k = cv2.waitKey(1) & 0xff
        if k == 27 : break

现在,不是让用户暂停视频并在对象周围绘制边界框,而是如何使它能够自动检测我感兴趣的特定对象(在我的情况下是牙刷)在视频源中引入,然后跟踪?

我找到了 this文章,其中讨论了我们如何使用 ImageAI 和 Yolo 检测视频中的对象。

from imageai.Detection import VideoObjectDetection
import os
import cv2

execution_path = os.getcwd()

camera = cv2.VideoCapture(0) 

detector = VideoObjectDetection()
detector.setModelTypeAsYOlov3()
detector.setModelPath(os.path.join(execution_path,"yolo.h5"))
detector.loadModel()

video_path = detector.detectObjectsFromVideo(camera_input=camera,output_file_path=os.path.join(execution_path,"camera_detected_1"),frames_per_second=29,log_progress=True)
print(video_path)

现在,Yolo 确实可以检测牙刷,它是它认可以检测的 80 多种物体之一。但是,这篇文章有两点使它对我来说不是理想的解决方案:

  1. 方法首先分析每个视频帧(每帧大约需要 1-2 秒,所以大约需要 1 分钟来分析来自网络摄像头的 2-3 秒视频流),并将检测到的视频保存在单独的视频中文件。而我想实时检测网络摄像头视频源中的牙刷。有解决办法吗?

  2. 正在使用的 Yolo v3 模型可以检测所有 80 个对象,但我只希望检测 2 或 3 个对象 - 牙刷、拿着牙刷的人和背景,如果需要的话。那么,有没有一种方法可以通过仅选择这 2 或 3 个对象进行检测来减少模型权重?

解决方法

如果您想要一个快速简便的解决方案,您可以使用更轻量级的 yolo 文件之一。你可以从这个网站获取权重和配置文件(它们成对出现,必须一起使用):https://pjreddie.com/darknet/yolo/(别担心,它看起来是草图,但很好)

使用较小的网络会让您获得更高的 fps,但也会降低准确性。如果这是您愿意接受的权衡,那么这是最简单的做法。

这是一些检测牙刷的代码。第一个文件只是一个类文件,有助于更无缝地使用 Yolo 网络。第二个是打开 VideoCapture 并将图像提供给网络的“主”文件。

yolo.py

import cv2
import numpy as np

class Yolo:
    def __init__(self,cfg,weights,names,conf_thresh,nms_thresh,use_cuda = False):
        # save thresholds
        self.ct = conf_thresh;
        self.nmst = nms_thresh;

        # create net
        self.net = cv2.dnn.readNet(weights,cfg);
        print("Finished: " + str(weights));
        self.classes = [];
        file = open(names,'r');
        for line in file:
            self.classes.append(line.strip());

        # use gpu + CUDA to speed up detections
        if use_cuda:
            self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA);
            self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA);

        # get output names
        layer_names = self.net.getLayerNames();
        self.output_layers = [layer_names[i[0]-1] for i in self.net.getUnconnectedOutLayers()];

    # runs detection on the image and draws on it
    def detect(self,img,target_id):
        # get detection stuff
        b,c,ids,idxs = self.get_detection_data(img,target_id);

        # draw result
        img = self.draw(img,b,idxs);
        return img,len(idxs);

    # returns boxes,confidences,class_ids,and indexes (indices?)
    def get_detection_data(self,target_id):
        # get output
        layer_outputs = self.get_inf(img);

        # get dims
        height,width = img.shape[:2];

        # filter thresholds and target
        b,idxs = self.thresh(layer_outputs,width,height,target_id);
        return b,idxs;

    # runs the network on an image
    def get_inf(self,img):
        # construct a blob
        blob = cv2.dnn.blobFromImage(img,1 / 255.0,(416,416),swapRB=True,crop=False);

        # get response
        self.net.setInput(blob);
        layer_outputs = self.net.forward(self.output_layers);
        return layer_outputs;

    # filters the layer output by conf,nms and id
    def thresh(self,layer_outputs,target_id):
        # some lists
        boxes = [];
        confidences = [];
        class_ids = [];

        # each layer outputs
        for output in layer_outputs:
            for detection in output:
                # get id and confidence
                scores = detection[5:];
                class_id = np.argmax(scores);
                confidence = scores[class_id];

                # filter out low confidence
                if confidence > self.ct and class_id == target_id:
                    # scale bounding box back to the image size
                    box = detection[0:4] * np.array([width,height]);
                    (cx,cy,w,h) = box.astype('int');

                    # grab the top-left corner of the box
                    tx = int(cx - (w / 2));
                    ty = int(cy - (h / 2));

                    # update lists
                    boxes.append([tx,ty,int(w),int(h)]);
                    confidences.append(float(confidence));
                    class_ids.append(class_id);

        # apply NMS
        idxs = cv2.dnn.NMSBoxes(boxes,self.ct,self.nmst);
        return boxes,idxs;

    # draw detections on image
    def draw(self,boxes,idxs):
        # check for zero
        if len(idxs) > 0:
            # loop over indices
            for i in idxs.flatten():
                # extract the bounding box coords
                (x,y) = (boxes[i][0],boxes[i][1]);
                (w,h) = (boxes[i][2],boxes[i][3]);

                # draw a box
                cv2.rectangle(img,(x,y),(x+w,y+h),(0,255),2);

                # draw text
                text = "{}: {:.4}".format(self.classes[class_ids[i]],confidences[i]);
                cv2.putText(img,text,y-5),cv2.FONT_HERSHEY_SIMPLEX,0.5,2);
        return img;

main.py

import cv2
import numpy as np

# this is the "yolo.py" file,I assume it's in the same folder as this program
from yolo import Yolo

# these are the filepaths of the yolo files
weights = "yolov3-tiny.weights";
config = "yolov3-tiny.cfg";
labels = "yolov3.txt";

# init yolo network
target_class_id = 79; # toothbrush
conf_thresh = 0.4; # less == more boxes (but more false positives)
nms_thresh = 0.4; # less == more boxes (but more overlap)
net = Yolo(config,labels,nms_thresh);

# open video capture
cap = cv2.VideoCapture(0);

# loop
done = False;
while not done:
    # get frame
    ret,frame = cap.read();
    if not ret:
        done = cv2.waitKey(1) == ord('q');
        continue;

    # do detection
    frame,_ = net.detect(frame,target_class_id);

    # show
    cv2.imshow("Marked",frame);
    done = cv2.waitKey(1) == ord('q');

如果您不想使用较轻的权重文件,有几个选项可供您选择。

如果您有 Nvidia GPU,您可以使用 CUDA大幅提高您的 fps。即使是普通的 nvidia gpu 也比仅在 cpu 上运行快几倍。

绕过持续运行检测成本的常见策略是仅使用它来最初获取目标。您可以使用来自神经网络的检测来初始化您的对象跟踪器,类似于一个人在对象周围绘制边界框。对象跟踪器的速度要快得多,而且无需每帧都进行全面检测。

如果您在单独的线程中运行 Yolo 和对象跟踪,那么您可以尽可能快地运行相机。您需要存储帧的历史记录,以便当 Yolo 线程完成一帧时,您可以检查旧帧以查看您是否已经在跟踪对象,这样您就可以在相应的帧上快速启动对象跟踪器-转发它让它赶上。这个程序并不简单,您需要确保正确管理线程之间的数据。不过,这是一个很好的练习,可以让您熟悉多线程,这是编程中的一大步。

,

我想在this article的帮助下回答这个问题,我之前也用过,遇到了和你类似的问题。以下是建议:

  • 使用 darknet framework 运行 YOLOv3,这将提高性能。
  • 在您的代码片段中,它看起来不允许您决定网络输入的宽度和高度,所以我不知道您使用的是什么。减小网络宽度和高度会提高速度,但会降低准确度。
  • YOLOv3 针对 80 个对象进行了训练,但您只需要其中的一些。我之前也只需要我项目中的汽车。不幸的是,您无法操作已经训练好的权重文件,也无法很好地训练您的对象。
  • 我之前也尝试过的另一种方式是将 YOLOv3 转移到另一个线程,并且我也没有将 yolo 应用于所有帧。我只应用了其中的一些,例如:每 10 帧中的 1 帧。这对我也很有帮助。
  • 或者你可以选择更好的 CPU 电脑 :​​)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?