如何解决TFRecord encode_raw用于序列功能
我有TFRecord格式的数据集:
def _bytes_feature(value):
if isinstance(value,type(tf.constant(0))):
value = value.numpy()
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
sequence_dict = {
'frames': tf.train.FeatureList(feature=frames),"label": tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[token])) for token in tokens]),}
context_dict = {
"frames_count": tf.train.Feature(int64_list=tf.train.Int64List(value=[frames_count])),"num_tokens": num_tokens,}
sequence_context = tf.train.Features(feature=context_dict)
sequence_list = tf.train.FeatureLists(feature_list=sequence_dict)
example = tf.train.SequenceExample(context=sequence_context,feature_lists=sequence_list)
-
frames
是一系列112x112
灰度图像,由_bytes_feature
函数的结果列表表示。 -
label
是令牌序列。
我的任务是seq2seq
,所以令牌的整个序列对应于帧的整个序列(len(frames) != len(label)
)。如果这更有意义,那么这项任务就是认真阅读。
我以这种方式加载数据集:
sequence_features = {
'frames': tf.io.FixedLenSequenceFeature([],dtype=tf.string),"label": tf.io.FixedLenSequenceFeature([],dtype=tf.int64),}
context_features = {
"frames_count": tf.io.FixedLenFeature([],"num_tokens": tf.io.FixedLenFeature([],}
dataset = tf.data.TFRecordDataset("train-0.tfrecord")
dataset = dataset.map(_parse_function)
dataset = dataset.padded_batch(3)
问题是,我无法正确编写_parse_function
,所以我可以遍历填充tf.int8
张量序列的填充批次序列,这些序列代表一个视频的帧以及批次的相应的标签。我也想避免使用VarLenFeature
,因为稀疏张量在CTC loss上的GPU上无法很好地发挥作用。
这是我尝试过的:
def _parse_function(example_proto):
context,sequence,_ = tf.io.parse_sequence_example(example_proto,context_features=context_features,sequence_features=sequence_features)
image = tf.io.decode_raw(sequence["frames"],tf.int8)
label = sequence["label"]
return image,label
抛出InvalidArgumentError: DecodeRaw requires input strings to all be the same size,but element 1 has size 2444 != 2456 [[{{node DecodeRaw}}]]
将parse_sequence_example
更改为parse_single_sequence_example
并没有帮助,并引发相同的错误
所以问题是,我应该如何修改_parse_function
以使其返回tf.int8
帧序列的批次,形状为BxTxWxH
,其中B是批次大小T是序列长度吗?
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。