
Encoding nested objects in TFRecord


I'm new to TensorFlow, and I'm trying to split a large dataset into TFRecords. The format I'm encoding looks like this:

  • ID (string, bytes)
  • Index (int64)
  • Time (int64)
  • Image (image, bytes)
  • Labels (list of Label objects, bytes)

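Each Label object has FrameID (int64), Category (int64), x1 (float), x2 (float), y1 (float) and y2 (float). However, I'm struggling to serialize this information. I split the label list into separate lists, one per attribute of the object (i.e. id[], category[], ...).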
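Splitting a list of objects into one list per attribute (a "struct of arrays") is what lets each attribute become a single int64 or float list feature. A minimal pure-Python sketch of that step and its inverse follows; the `Label` dataclass, `FIELDS`, `labels_to_columns` and `columns_to_labels` are hypothetical names chosen for illustration, and the values are the ones the question later writes to the file.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Label:
    """Hypothetical stand-in for the question's Label object."""
    frame_id: int
    category: int
    x1: float
    x2: float
    y1: float
    y2: float


# Attribute order must match the dataclass field order for Label(*row) below.
FIELDS = ["frame_id", "category", "x1", "x2", "y1", "y2"]


def labels_to_columns(labels: List[Label]) -> Dict[str, list]:
    """Splits a list of Label objects into one parallel list per attribute."""
    return {name: [getattr(label, name) for label in labels] for name in FIELDS}


def columns_to_labels(columns: Dict[str, list]) -> List[Label]:
    """Rebuilds Label objects by zipping the parallel lists back together."""
    rows = zip(*(columns[name] for name in FIELDS))
    return [Label(*row) for row in rows]


labels = [Label(3, 1, 2.2, 4.4, 6.6, 8.8), Label(4, 2, 3.3, 5.5, 7.7, 9.9)]
columns = labels_to_columns(labels)
print(columns["frame_id"])  # [3, 4]
print(columns["x1"])        # [2.2, 3.3]
```

Note that FloatList stores float32, so after a round trip through a TFRecord the floats come back as, for example, 2.2000000476837 rather than exactly 2.2; compare them with a tolerance rather than `==`.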

Currently, this is how I serialize the individual elements, adapted from the TFRecord documentation page:

import tensorflow as tf

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _float_list_feature(value):
  """Returns a float_list from a list of floats / doubles."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int64_list_feature(value):
  """Returns an int64_list from a list of bools / enums / ints / uints."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
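The scalar helpers wrap their argument in a one-element list, while the `_list` variants store every element they are given. The short sketch below re-declares two of the helpers so it runs on its own and inspects the repeated `value` field of each resulting `tf.train.Feature`; the number of stored values is exactly what the parsing side later has to agree with.

```python
import tensorflow as tf


def _int64_feature(value):
    """Scalar helper: wraps the single value in a one-element list."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _int64_list_feature(value):
    """List helper: stores every element of `value`."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


scalar = _int64_feature(3)
listed = _int64_list_feature([3, 4])

# One stored value for the scalar helper, two for the list helper.
print(list(scalar.int64_list.value))  # [3]
print(list(listed.int64_list.value))  # [3, 4]
```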

This is how I write the data to a TFRecord file:

def serialize_header(feature0, feature1, feature2, feature3, feature4,
                     feature5, feature6, feature7, feature8, feature9):
    """
    Creates a tf.train.Example message ready to be written to a file.
    """
    # Create a dictionary mapping the feature name to the tf.train.Example-compatible data type.
    feature = {
        'id': _bytes_feature(feature0),
        'index': _int64_feature(feature1),
        'time': _int64_feature(feature2),
        'image': _bytes_feature(feature3),
        'frame_id': _int64_list_feature(feature4),
        'category': _int64_list_feature(feature5),
        'x1': _float_list_feature(feature6),
        'x2': _float_list_feature(feature7),
        'y1': _float_list_feature(feature8),
        'y2': _float_list_feature(feature9),
    }
    # Create a Features message using tf.train.Example.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# Read the image once and close the file handle.
with open("b1c66a42-6f7d68ca.jpg", 'rb') as f:
    image_bytes = f.read()

with tf.io.TFRecordWriter('test.tfrecords') as writer:
    result = serialize_header(b'TestID', 3, 4, image_bytes, [3, 4], [1, 2],
                              [2.2, 3.3], [4.4, 5.5], [6.6, 7.7], [8.8, 9.9])
    writer.write(result)

So far, everything works. Only when I try to read the data back from the dataset do I run into an error:

raw_dataset = tf.data.TFRecordDataset('test.tfrecords')

# Create a dictionary describing the features.
feature_description = {
    'id': tf.io.FixedLenFeature([], tf.string),
    'index': tf.io.FixedLenFeature([], tf.int64),
    'time': tf.io.FixedLenFeature([], tf.int64),
    'image': tf.io.FixedLenFeature([], tf.string),
    'frame_id': tf.io.FixedLenFeature([], tf.int64),
    'category': tf.io.FixedLenFeature([], tf.int64),
    'x1': tf.io.FixedLenFeature([], tf.float32),
    'x2': tf.io.FixedLenFeature([], tf.float32),
    'y1': tf.io.FixedLenFeature([], tf.float32),
    'y2': tf.io.FixedLenFeature([], tf.float32),
}

def _parse_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto,feature_description)

parsed_dataset = raw_dataset.map(_parse_function)
print(parsed_dataset)

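from IPython.display import Image, display  # notebook display helpers

for image_features in parsed_dataset:
  image_raw = image_features['image'].numpy()  # the JPEG bytes are stored under 'image'
  display(Image(data=image_raw))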

The error:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-32-c5d6610d5b7f> in <module>()
     49 print(parsed_dataset)
     50 
---> 51 for image_features in parsed_dataset:
     52   image_raw = image_features['id'].numpy()
     53   display(Image(data=image_raw))
InvalidArgumentError: Key: y2.  Can't parse serialized Example.
     [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]]
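The message points at `y2`, the last of the list features. A minimal reproduction, assuming TF 2.x eager execution, suggests why: `FixedLenFeature([])` declares a scalar (exactly one value per record), so a feature that actually holds two floats cannot be parsed into it. A fixed shape of `[2]` does match, but only when every record has exactly two values, which is not the case for a variable number of labels.

```python
import tensorflow as tf

# An Example whose 'y2' feature holds two floats, as in the question's data.
example = tf.train.Example(features=tf.train.Features(feature={
    "y2": tf.train.Feature(float_list=tf.train.FloatList(value=[8.8, 9.9])),
}))
serialized = example.SerializeToString()

# FixedLenFeature([]) declares a scalar: exactly one value per record.
scalar_spec = {"y2": tf.io.FixedLenFeature([], tf.float32)}
try:
    tf.io.parse_single_example(serialized, scalar_spec)
    scalar_failed = False
except tf.errors.InvalidArgumentError:
    # Two stored values do not fit a scalar shape.
    scalar_failed = True

# A fixed shape of [2] parses, but only if every record has exactly two values.
pair_spec = {"y2": tf.io.FixedLenFeature([2], tf.float32)}
parsed = tf.io.parse_single_example(serialized, pair_spec)

print(scalar_failed)               # True
print(parsed["y2"].numpy().shape)  # (2,)
```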

I can't tell whether I encoded the data correctly and decoded it wrongly, the other way around, or both. Some expert help would be great.
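One way to separate the two questions is to decode a record with no feature spec at all: `tf.train.Example.FromString` parses the raw protobuf, and the `kind` oneof of each `tf.train.Feature` reports which list type was written and how many values it holds. The helper `inspect_first_record` and the temporary file below are hypothetical, written only for this sketch; if the counts look right, the encoding is fine and the problem is on the parsing side.

```python
import os
import tempfile

import tensorflow as tf


def inspect_first_record(path):
    """Hypothetical helper: decodes the first record in `path` without a feature spec."""
    for raw_record in tf.data.TFRecordDataset(path).take(1):
        example = tf.train.Example.FromString(raw_record.numpy())
        summary = {}
        for name, feature in example.features.feature.items():
            kind = feature.WhichOneof("kind")  # 'bytes_list', 'float_list' or 'int64_list'
            summary[name] = (kind, len(getattr(feature, kind).value))
        return summary
    return {}  # the file contained no records


# Write a tiny file shaped like the question's data.
path = os.path.join(tempfile.mkdtemp(), "inspect.tfrecords")
example = tf.train.Example(features=tf.train.Features(feature={
    "id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"TestID"])),
    "y2": tf.train.Feature(float_list=tf.train.FloatList(value=[8.8, 9.9])),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

summary = inspect_first_record(path)
for name in sorted(summary):
    print(name, summary[name])
# id ('bytes_list', 1)
# y2 ('float_list', 2)
```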

Solution

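Since the lists were written with _int64_list_feature / _float_list_feature, parse them with tf.io.VarLenFeature(tf.int64) / tf.io.VarLenFeature(tf.float32) instead of FixedLenFeature([], tf.int64) / FixedLenFeature([], tf.float32). FixedLenFeature([]) declares a scalar, i.e. exactly one value per record, so a feature holding two values (such as y2 = [8.8, 9.9]) cannot be parsed. VarLenFeature accepts any number of values and returns a SparseTensor.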
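A minimal sketch of that fix, assuming TF 2.x: the list features are declared with `tf.io.VarLenFeature`, and `tf.sparse.to_dense` turns each resulting SparseTensor back into an ordinary 1-D tensor. `_serialize` is a hypothetical stand-in for the question's `serialize_header`, reduced to two features, and an in-memory dataset replaces the file so the example runs on its own.

```python
import tensorflow as tf


def _serialize(frame_ids, x1):
    """Hypothetical writer: builds a serialized Example with an int64 list and a float list."""
    feature = {
        "frame_id": tf.train.Feature(int64_list=tf.train.Int64List(value=frame_ids)),
        "x1": tf.train.Feature(float_list=tf.train.FloatList(value=x1)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


# VarLenFeature accepts any number of values per record and yields a SparseTensor.
feature_description = {
    "frame_id": tf.io.VarLenFeature(tf.int64),
    "x1": tf.io.VarLenFeature(tf.float32),
}


def _parse_function(example_proto):
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    # Densify each SparseTensor into an ordinary 1-D tensor.
    return {name: tf.sparse.to_dense(value) for name, value in parsed.items()}


# Records with different label counts parse with the same spec.
records = [_serialize([3, 4], [2.2, 3.3]), _serialize([7], [1.5])]
dataset = tf.data.Dataset.from_tensor_slices(records).map(_parse_function)
decoded = [{k: v.numpy().tolist() for k, v in d.items()} for d in dataset]

print(decoded[0]["frame_id"])  # [3, 4]
print(decoded[1]["frame_id"])  # [7]
```

The same `_parse_function` can be passed to `tf.data.TFRecordDataset('test.tfrecords').map(...)`. If a dense tensor is wanted directly, `tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True)` should also accept a variable-length list in a plain `tf.train.Example`.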
