Tensorflow TFWriter 不正确的数据序列化

如何解决Tensorflow TFWriter 不正确的数据序列化

我有一个使用 MatLab 的 ImageLabeller 创建的数据集,当尝试按照官方教程的说明将数据集转换为 TFRecord 时,某些坐标不正确,看起来好像最小值大于最大值。我尝试删除失败的示例,但似乎错误与此无关,失败的示例总是出现在相同的位置。我尝试使用来自 MODD2 的图像和使用较大图像的 ImageLabeller 创建的数据集,并且它可以正常工作。

用于生成 TFrecord 文件的代码如下:

# MODD2 format: x y w h -> x,y are the top left corner coordinates
def read_drone_mat_file(file_number):
    """Load the annotation .mat file at position *file_number* in drones_dir.

    Returns a tuple (bbox_d, bbox_o, filename):
      bbox_d   -- list of drone boxes, each [x, y, w, h]
      bbox_o   -- list of obstacle boxes, each [x, y, w, h]
      filename -- single-element list with the first 9 chars of the .mat name
                  (presumably the frame id stem — TODO confirm against dataset)
    """
    mat = os.listdir(drones_dir)[file_number]
    frame = os.path.join(drones_dir, mat)
    # mat_dtype=True makes loadmat return np.float64 instead of the stored
    # np.uint8; without it, x+w / y+h arithmetic downstream wraps around at
    # 255 and produces boxes with xmin > xmax (the bug this post is about).
    data = sio.loadmat(frame, mat_dtype=True)

    bbox_d = [obj for obj in data['drone']]
    bbox_o = [obj for obj in data['obstacles']]
    filename = [mat[0:9]]

    return bbox_d, bbox_o, filename

# %% Helper function to create a tfexample for the drone data
def create_drone_tfexample(drones, obstacles, index, image_path):
    """Build a tf.train.Example for one image and its bounding boxes.

    Parameters:
      drones     -- list of drone boxes [x, y, w, h] in pixels
      obstacles  -- list of obstacle boxes [x, y, w, h] in pixels
      index      -- position of the image within image_path's listing
                    (offset by 2 — presumably to skip non-image entries;
                    TODO confirm the offset against the directory layout)
      image_path -- directory containing the JPEG frames

    Side effect: increments the function attribute
    create_drone_tfexample.source_id (used as a running example id).
    """
    image_format = b'jpg'
    filename = os.listdir(image_path)[index + 2]

    # load corresponding image (only use left images)
    with tf.io.gfile.GFile(os.path.join(image_path, filename), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size
    wsize, hsize = (width, height)

    filename = os.path.splitext(filename)[0].encode('utf-8')
    create_drone_tfexample.source_id += 1
    source_id_s = "{}".format(create_drone_tfexample.source_id).encode('utf-8')

    # tfrecord features definition
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    def _append_box(obj, label_text, label_id):
        # Cast to float BEFORE arithmetic: scipy.io.loadmat (without
        # mat_dtype=True) yields np.uint8 values, so x+w / y+h silently
        # wraps around at 255 and produces xmax < xmin / ymax < ymin.
        x, y, w, h = (float(v) for v in obj[:4])
        xmins.append(x / width)
        xmaxs.append((x + w) / width)
        ymins.append(y / height)
        ymaxs.append((y + h) / height)
        classes_text.append(label_text)
        classes.append(label_id)

    for obj in drones:
        _append_box(obj, bytes('drone', 'utf-8'), 2)     # class 2 = drone

    for obj in obstacles:
        _append_box(obj, bytes('obstacles', 'utf-8'), 1)  # class 1 = obstacle

    print(source_id_s + b": " + filename)

    # create tf_example
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(hsize),
        'image/width': dataset_util.int64_feature(wsize),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(source_id_s),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))

    return tf_example


create_drone_tfexample.source_id = 0

# %% Create final dataset WARNING: Slow and destructive
train_writer = tf.io.TFRecordWriter(
    output_dir+'drone_train_truncated.tfrecord')
test_writer = tf.io.TFRecordWriter(output_dir+'drone_test_truncated.tfrecord')
drone_test_writer = tf.io.TFRecordWriter(
    output_dir + 'drone_only_test.tfrecord')
create_drone_tfexample.source_id = 0

# Drones dataset
for index,mat in enumerate(os.listdir(drones_dir)):
    boxes_d,boxes_o,filename = read_drone_mat_file(index)
    print()
    # Pass the bounding boxes to the create_tfexample function
    if index < 210:
        image_path = drones_image_root
        tf_example = create_drone_tfexample(
            boxes_d,image_path)

    # Write the tf_example into the dataset
    if random.randint(1,100) <= 80:  # 80% Train  20% Validation
        train_writer.write(tf_example.SerializeToString())
    else:
        test_writer.write(tf_example.SerializeToString())
        drone_test_writer.write(tf_example.SerializeToString())

示例在尝试使用它们进行训练时失败,为了阅读示例,我使用以下代码:

# %% Extract images from dataset
dataset_file = "drone_only_test.tfrecord"
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force CPU: GPU not needed to inspect records
raw_dataset = tf.data.TFRecordDataset(
    "<path_to_dataset>" + dataset_file)

print('_______________________________________________________________________________________')
# BUG fix: the original dict had lost the dtype argument on several
# FixedLenFeature entries ('image/width', 'image/source_id', 'image/encoded',
# 'image/format'), making it invalid — each entry needs (shape, dtype).
# Schema mirrors the features written by create_drone_tfexample.
image_feature_description = {
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/source_id': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/format': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
}


def _parse_image_function(example_proto):
    """Deserialize one tf.train.Example using the module-level schema."""
    parsed = tf.io.parse_single_example(
        serialized=example_proto,
        features=image_feature_description,
    )
    return parsed


parsed_image_dataset = raw_dataset.map(_parse_image_function)

# Inspect the first 10 examples: save the JPEG and print the boxes in pixel
# coordinates (the record stores them normalized to [0, 1]).
for image_features in parsed_image_dataset.take(10):
    image_raw = image_features['image/encoded'].numpy()
    display.display(display.Image(data=image_raw))
    encoded_jpg_io = io.BytesIO(image_raw)
    image = Image.open(encoded_jpg_io)
    image.save("out.jpg", format="JPEG")
    # Generalized: denormalize with the width/height stored in each record
    # instead of the hard-coded 640x480 of the original script, so the dump
    # stays correct for datasets with other image sizes.
    w = float(image_features['image/width'].numpy())
    h = float(image_features['image/height'].numpy())
    xmin = image_features['image/object/bbox/xmin'].values * w
    xmax = image_features['image/object/bbox/xmax'].values * w
    ymin = image_features['image/object/bbox/ymin'].values * h
    ymax = image_features['image/object/bbox/ymax'].values * h
    print(f'ID: {image_features["image/filename"]}')
    print(f'XMIN: {xmin}')
    print(f'XMAX: {xmax}')
    print(f'YMIN: {ymin}')
    print(f'YMAX: {ymax}')
    print('---------------------')
    print(f'WIDTH: {xmax - xmin}')
    print(f'HEIGHT: {ymax - ymin}')

对于第四个位置的例子,输出如下:

ID: b'color_00000036'
XMIN: [179. 175.   5.]
XMAX: [387. 210.  21.]
YMIN: [263. 193. 242.]
YMAX: [372.   6. 248.]
---------------------
WIDTH: [208.  35.  16.]
HEIGHT: [ 109. -187.    6.]

相同图像的matlab输出如下:

ground_truth =

   179   175     5
   263   193   242
   208    35    16
   109    69     6

使用的版本如下:

  • Windows 10 64 位
  • Python 3.7.9 64 位
  • TensorFlow 2.4.0
  • Scipy 1.5.4
  • TensorFlow Object Detection API(master 分支)

解决方法

问题最终与 scipy.io.loadmat() 将数据转换为 np.uint8 相关,解决方案是将 mat_dtype=True 作为参数传递,以便将所有内容加载为 np.float64。 不是最有效的方法,但它有效。

非常感谢。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -&gt; systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping(&quot;/hires&quot;) public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate&lt;String
使用vite构建项目报错 C:\Users\ychen\work&gt;npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-
参考1 参考2 解决方案 # 点击安装源 协议选择 http:// 路径填写 mirrors.aliyun.com/centos/8.3.2011/BaseOS/x86_64/os URL类型 软件库URL 其他路径 # 版本 7 mirrors.aliyun.com/centos/7/os/x86
报错1 [root@slave1 data_mocker]# kafka-console-consumer.sh --bootstrap-server slave1:9092 --topic topic_db [2023-12-19 18:31:12,770] WARN [Consumer clie
错误1 # 重写数据 hive (edu)&gt; insert overwrite table dwd_trade_cart_add_inc &gt; select data.id, &gt; data.user_id, &gt; data.course_id, &gt; date_format(
错误1 hive (edu)&gt; insert into huanhuan values(1,&#39;haoge&#39;); Query ID = root_20240110071417_fe1517ad-3607-41f4-bdcf-d00b98ac443e Total jobs = 1
报错1:执行到如下就不执行了,没有显示Successfully registered new MBean. [root@slave1 bin]# /usr/local/software/flume-1.9.0/bin/flume-ng agent -n a1 -c /usr/local/softwa
虚拟及没有启动任何服务器查看jps会显示jps,如果没有显示任何东西 [root@slave2 ~]# jps 9647 Jps 解决方案 # 进入/tmp查看 [root@slave1 dfs]# cd /tmp [root@slave1 tmp]# ll 总用量 48 drwxr-xr-x. 2
报错1 hive&gt; show databases; OK Failed with exception java.io.IOException:java.lang.RuntimeException: Error in configuring object Time taken: 0.474 se
报错1 [root@localhost ~]# vim -bash: vim: 未找到命令 安装vim yum -y install vim* # 查看是否安装成功 [root@hadoop01 hadoop]# rpm -qa |grep vim vim-X11-7.4.629-8.el7_9.x
修改hadoop配置 vi /usr/local/software/hadoop-2.9.2/etc/hadoop/yarn-site.xml # 添加如下 &lt;configuration&gt; &lt;property&gt; &lt;name&gt;yarn.nodemanager.res