如何解决graph_def.ParseFromString(f.read()) google.protobuf.message.DecodeError:解析消息时出错
我试图使用我训练好的 ssd_resnet50_v1_fpn saved_model.pb
来推断视频,这是由 TensorFlow Object Detection API 2.0 版训练的,我的 TF 版本是 2.4.0。我使用附加的脚本来运行推理,但出现了令人头疼的错误消息:
File "ObjDet_Test_Video.py",line 35,in <module>
od_graph_def.ParseFromString(serialized_graph)
google.protobuf.message.DecodeError: Error parsing message
下面是我的推理脚本。有大佬知道怎么解决吗?
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import cv2
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import datetime as dt
from objectTracker import *
from collections import deque
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import ops as utils_ops
from utils import label_map_util
from utils import visualization_utils as vis_util
# Paths to the exported model graph and the label map.
PATH_TO_CKPT = 'exported-models/my_model/saved_model/saved_model.pb'
PATH_TO_LABELS = os.path.join('annotations','label_map.pbtxt')
NUM_CLASSES = 17

# Load Tensorflow model to memory.
#
# NOTE(review): this is the source of the reported
# google.protobuf.message.DecodeError.  A TF2 Object Detection API export
# writes a *SavedModel* protobuf (saved_model.pb), NOT a frozen GraphDef,
# so GraphDef.ParseFromString() cannot decode it.  Either load it with
#     model = tf.saved_model.load('exported-models/my_model/saved_model')
# and call its serving signature directly, or export a frozen inference
# graph and point PATH_TO_CKPT at that file.  The code below only works
# for a frozen-graph .pb.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

# Loading label map: maps the model's integer class ids to display names.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
# Helper code
def load_image_into_numpy_array(image):
    """Convert a PIL-style image into an (H, W, 3) uint8 numpy array.

    `image` must expose PIL.Image's interface: a ``size`` attribute of
    (width, height) and ``getdata()`` returning the flat pixel sequence.
    """
    (im_width, im_height) = image.size
    # PIL reports (width, height) but numpy arrays are indexed (row, col),
    # hence the (height, width, 3) reshape order.
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)
def cvDrawBoundingBox(xLeft, yTop, width, height, bgr, img):
    """Draw a `bgr`-colored rectangle of the given size at (xLeft, yTop).

    Returns the image with the rectangle drawn (cv2 also draws in place).
    """
    topLeft = (xLeft, yTop)
    botRight = (xLeft + width, yTop + height)
    thickness = 2
    # BUG FIX: the original passed `thickness` where cv2.rectangle expects
    # the color, so the `bgr` parameter was silently ignored; pass both
    # color and thickness in their proper positions.
    img = cv2.rectangle(img, topLeft, botRight, bgr, thickness)
    return img
def cvDrawText(txt, x, y, img, bgr=(0, 0, 255)):
    """Render `txt` onto `img` with its bottom-left corner at (x, y).

    BUG FIX: the original did ``color = bgr`` with no ``bgr`` in scope,
    raising NameError on every call.  ``bgr`` is now a parameter with a
    red (BGR-order) default, keeping the existing 4-argument calls valid.
    """
    botLeft_xy = (x, y)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_size = 0.66
    thickness = 1
    lineType = cv2.LINE_AA
    cv2.putText(img, txt, botLeft_xy, font, font_size, bgr, thickness, lineType)
    return img
def cvDrawTrackBlobs(trackList, img):
    """Draw every track's bounding box and id label onto `img`.

    Confirmed tracks are drawn in red, tentative ones in white.
    Returns the annotated image.
    """
    for track in trackList:
        boxColor = (255, 255, 255)  # white: tentative track
        if track.confirmed:
            # BUG FIX: the "(0,255)" literal was garbled; red in BGR order
            # is a 3-tuple.
            boxColor = (0, 0, 255)  # red: confirmed track
        x = track.blob.xLeft
        y = track.blob.yTop
        w = track.blob.width
        h = track.blob.height
        # BUG FIX: the original dropped the yTop argument, shifting every
        # argument of cvDrawBoundingBox by one position.
        img = cvDrawBoundingBox(x, y, w, h, boxColor, img)
        blobClass = " "
        txt = str(track.id) + blobClass
        # Label slightly above the box's top-left corner.
        img = cvDrawText(txt, x - 2, y - 5, img)
    return img
# Detection: run the frozen graph over every video frame, track the
# resulting blobs, and display both the detection and tracking overlays.
IMAGE_SIZE = (12, 8)
with detection_graph.as_default():
    # BUG FIX: tf.Session was removed in TF 2.x; the v1 session API lives
    # under tf.compat.v1 (this script runs on TF 2.4).
    with tf.compat.v1.Session(graph=detection_graph) as sess:
        # Input and output tensors of the frozen detection graph.
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score is the confidence for the corresponding detection;
        # it is shown on the result image together with the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        cap = cv2.VideoCapture('N_Oakland.avi')
        delay_ms = -1  # negative => wait for a key press between frames
        # Dictionary of track ids -> in-use flag, consumed by trackBlobs.
        trackDict = {id: False for id in range(1, 251)}
        trackedObjList = []
        trackTraceQueue = deque([], 30)
        frameNum = 0
        while cap.isOpened():
            numBlobs = 0
            blobList = []
            ret, frame = cap.read()
            # BUG FIX: stop cleanly at end of stream instead of crashing on
            # frame is None.
            if not ret:
                break
            blobFrame = frame.copy()
            # BUG FIX: a color frame's shape is (H, W, 3); the original
            # unpacked two names (ValueError) and never defined im_width,
            # which is used below to scale the normalized boxes.
            im_height, im_width = frame.shape[:2]
            # Convert OpenCV color format (BGR) to RGB as expected by Tensorflow.
            image_np = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # The model expects a batch: shape [1, height, width, 3].
            image_np_expanded = np.expand_dims(image_np, axis=0)
            time_start = dt.datetime.now()
            # Actual detection.
            (boxes, scores, classes, num) = sess.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            time_end = dt.datetime.now()
            elapsed_time = time_end - time_start
            # Object tracking
            numObjs = int(num[0])
            print()
            print("FrameNo:", frameNum)
            print("Objects:", numObjs)
            for idx in range(numObjs):
                # Drop low-confidence detections.
                if scores[0][idx] < 0.5:
                    continue
                ymin, xmin, ymax, xmax = boxes[0][idx]
                # Boxes are normalized [0, 1]; scale to pixel coordinates.
                xmin = int(xmin * im_width)
                xmax = int(xmax * im_width)
                ymin = int(ymin * im_height)
                ymax = int(ymax * im_height)
                objW = xmax - xmin
                objH = ymax - ymin
                # Keep blobs that are on-screen and smaller than half the frame.
                if ymax > 0 and objW < im_width / 2 and objH < im_height / 2:
                    # NOTE(review): blobAttr's signature comes from
                    # objectTracker (not visible here) — confirm it takes
                    # (blob index, detection index, area).
                    blobList.append(blobAttr(numBlobs, idx, objW * objH))
                    numBlobs += 1
            """Track moving objects"""
            blobList, trackedObjList, trackDict = trackBlobs(blobList, trackDict)
            blobFrame = cvDrawTrackBlobs(trackedObjList, blobFrame)
            # Visualization of the results of a detection.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32),
                np.squeeze(scores), category_index,
                use_normalized_coordinates=True, line_thickness=2)
            opencv_image = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
            txt = 'Inference Time (secs): ' + str(elapsed_time.total_seconds())
            # BUG FIX: the cvDrawText calls below matched no 4-argument
            # signature (extra color tuple / missing y); pass (txt, x, y, img).
            opencv_image = cvDrawText(txt, 10, 20, opencv_image)
            # BUG FIX: MODEL_NAME was never defined; show the model path instead.
            txt = 'Model: ' + PATH_TO_CKPT[:30]
            opencv_image = cvDrawText(txt, 450, 20, opencv_image)
            txt = 'Frame No: ' + str(frameNum)
            # BUG FIX: the original line was a syntax error
            # (`cvDrawText(txt,30,0),blobFrame)`); draw on the track frame.
            blobFrame = cvDrawText(txt, 10, 30, blobFrame)
            cv2.imshow('Objects', opencv_image)
            cv2.imshow('Tracks', blobFrame)
            frameNum = frameNum + 1
            k = cv2.waitKey(delay_ms) & 0xff
            if k == 13:  # Enter toggles free-run / single-step
                if delay_ms < 0:
                    delay_ms = 30
                else:
                    delay_ms = -1
            if k == 32:  # Space: advance to the next frame
                continue
            elif k == 27:  # Esc: quit
                break
        cap.release()
        cv2.destroyAllWindows()
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。