微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

TFRecords 解析:如何从单个张量中检索多个图像?

如何解决TFRecords 解析:如何从单个张量中检索多个图像?

我正在尝试解析 coco 格式的数据集,其中包括(除其他外)输入图像和作为输出的图像列表(掩码)。 数据集已使用 efficientdet/dataset_tools/create_coco_record.py

转换为 tfrecords

以下是序列化的片段:

feature_dict = {
      'image/height':
          tfrecord_util.int64_feature(image_height),'image/width':
          tfrecord_util.int64_feature(image_width)
      'image/encoded':
          tfrecord_util.bytes_feature(encoded_jpg),}
...
for object_annotations in bBox_annotations:
    run_len_encoding = mask.frPyObjects(object_annotations['segmentation'],image_height,image_width)
    binary_mask = mask.decode(run_len_encoding)
    binary_mask = np.amax(binary_mask,axis=2)
    pil_image = PIL.Image.fromarray(binary_mask)
    output_io = io.BytesIO()
    pil_image.save(output_io,format='PNG')
    encoded_mask_png.append(output_io.getvalue()

if include_masks:
    feature_dict['image/object/mask'] = (
        tfrecord_util.bytes_list_feature(encoded_mask_png))

我的问题与 tfrecords 的解码有关,我无法解码掩码张量中的图像。

以下是我的解析函数

def parse_example(serialized_example):
  feature_dict = {
    'image/height': tf.io.FixedLenFeature([],tf.int64),'image/width': tf.io.FixedLenFeature([],'image/encoded':  tf.io.FixedLenFeature([],tf.string),'image/object/class/label': tf.io.FixedLenSequenceFeature([],tf.int64,allow_missing=True),'image/object/mask': tf.io.FixedLenSequenceFeature([],tf.string,}

  example = tf.io.parse_single_example(serialized_example,features=feature_dict)

  raw_height = tf.cast(example['image/height'],tf.int64)
  raw_width = tf.cast(example['image/width'],tf.int64)
  image = tf.image.decode_png(example['image/encoded'],channels=3)
  image = tf.image.resize(image,(512,512))
  labels = example['image/object/class/label'] 

  masks = tf.image.decode_png(example['image/object/mask'],channels=3)

我收到的错误

ValueError: Shape must be rank 0 but is rank 1 for '{{node DecodePng_1}} = DecodePngchannels=3,dtype=DT_UINT8' 输入形状:[?]。

我将如何解码向量中的多个图像?

解决方法

tf_example_decoder.py 找到解决方案。

以下是一些代码片段:

将图像读取为字符串类型的 VarLenFeatures

keys_to_features = {
...
'image/object/mask': tf.io.VarLenFeature(tf.string)
}

parsed_tensors = tf.io.parse_single_example(
        serialized=serialized_example,features=keys_to_features)

将稀疏张量转换为密集张量

for k in parsed_tensors:
  if isinstance(parsed_tensors[k],tf.SparseTensor):
    if parsed_tensors[k].dtype == tf.string:
      parsed_tensors[k] = tf.sparse.to_dense(
          parsed_tensors[k],default_value='')

然后使用以下方法解码掩码:

def _decode_masks(self,parsed_tensors):
    """Decode a set of PNG masks to the tf.float32 tensors."""
    def _decode_png_mask(png_bytes):
      mask = tf.squeeze(
          tf.io.decode_png(png_bytes,channels=1,dtype=tf.uint8),axis=-1)
      mask = tf.cast(mask,dtype=tf.float32)
      mask.set_shape([None,None])
      return mask

    height = parsed_tensors['image/height']
    width = parsed_tensors['image/width']
    masks = parsed_tensors['image/object/mask']
    return tf.cond(
        pred=tf.greater(tf.size(input=masks),0),true_fn=lambda: tf.map_fn(_decode_png_mask,masks,dtype=tf.float32),false_fn=lambda: tf.zeros([0,height,width],dtype=tf.float32))

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。