如何解决TFRecord编码嵌套对象
我是Tensorflow的新手,我试图将一个大型数据集分解为TFRecords。我正在编码的格式如下:
Label对象具有FrameID(int64),Category(int64),x1(Float),x2(Float),y1(Float),y2(Float) 但是,我正在努力将这些信息序列化。我将标签列表分解为与对象的属性(即id [],category [] ...)相对应的列表。
当前,这是序列化单个元素的方式,从TFRecord的文档页面采用:
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value,type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""Returns a float_list from a float / double."""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _float_list_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
"""Returns an int64_list from a bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _int64_list_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
这就是将数据写入tfrecords文件的方式。
def serialize_header(feature0,feature1,feature2,feature3,feature4,feature5,feature6,feature7,feature8,feature9):
"""
Creates a tf.train.Example message ready to be written to a file.
"""
# Create a dictionary mapping the feature name to the tf.train.Example-compatible data type.
feature = {
'id': _bytes_feature(feature0),'index': _int64_feature(feature1),'time': _int64_feature(feature2),'image': _bytes_feature(feature3),'frame_id': _int64_list_feature(feature4),'category': _int64_list_feature(feature5),'x1': _float_list_feature(feature6),'x2': _float_list_feature(feature7),'y1': _float_list_feature(feature8),'y2': _float_list_feature(feature9)
}
# Create a Features message using tf.train.Example.
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializetoString()
with tf.io.TFRecordWriter('test.tfrecords') as writer:
result = serialize_header(b'TestID',3,4,open("b1c66a42-6f7d68ca.jpg",'rb').read(),[3,4],[1,2],[2.2,3.3],[4.4,5.5],[6.6,7.7],[8.8,9.9])
print(result)
writer.write(result)
到目前为止,一切进展顺利。直到尝试从数据集中读取数据时,我才陷入错误。
raw_dataset = tf.data.TFRecordDataset('test.tfrecords')
# Create a dictionary describing the features.
feature_description = {
'id': tf.io.FixedLenFeature([],tf.string),'index': tf.io.FixedLenFeature([],tf.int64),'time': tf.io.FixedLenFeature([],'image': tf.io.FixedLenFeature([],'frame_id': tf.io.FixedLenFeature([],'category': tf.io.FixedLenFeature([],'x1': tf.io.FixedLenFeature([],tf.float32),'x2': tf.io.FixedLenFeature([],'y1': tf.io.FixedLenFeature([],'y2': tf.io.FixedLenFeature([],tf.float32)
}
def _parse_function(example_proto):
# Parse the input tf.train.Example proto using the dictionary above.
return tf.io.parse_single_example(example_proto,feature_description)
parsed_dataset = raw_dataset.map(_parse_function)
print(parsed_dataset)
for image_features in parsed_dataset:
image_raw = image_features['id'].numpy()
display(Image(data=image_raw))
错误所在:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-32-c5d6610d5b7f> in <module>()
49 print(parsed_dataset)
50
---> 51 for image_features in parsed_dataset:
52 image_raw = image_features['id'].numpy()
53 display(Image(data=image_raw))
InvalidArgumentError: Key: y2. Can't parse serialized Example.
[[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]]
我无法确定我是否正确编码了数据,但解码错误,反之亦然,或两者皆有。拥有某人的专业知识将是很棒的。
解决方法
使用_int64_list_feature
/ _float_list_feature
而非FixedLenFeature([],tf.int64/tf.float32)
创建时,请尝试tf.io.VarLenFeature(tf.int64/tf.float32)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。