微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

TFRecord encode_raw用于序列功能

如何解决TFRecord encode_raw用于序列功能

我有TFRecord格式的数据集:

def _bytes_feature(value):
    if isinstance(value,type(tf.constant(0))):
        value = value.numpy()

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

sequence_dict = {
    'frames': tf.train.FeatureList(feature=frames),"label": tf.train.FeatureList(feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[token])) for token in tokens]),}
context_dict = {
    "frames_count": tf.train.Feature(int64_list=tf.train.Int64List(value=[frames_count])),"num_tokens": num_tokens,}

sequence_context = tf.train.Features(feature=context_dict)
sequence_list = tf.train.FeatureLists(feature_list=sequence_dict)
example = tf.train.SequenceExample(context=sequence_context,feature_lists=sequence_list)
  • frames是一系列112x112灰度图像,由_bytes_feature函数的结果列表表示。
  • label是令牌序列。

我的任务是seq2seq,所以令牌的整个序列对应于帧的整个序列(len(frames) != len(label))。如果这更有意义,那么这项任务就是认真阅读。

我以这种方式加载数据集:

sequence_features = {
    'frames': tf.io.FixedLenSequenceFeature([],dtype=tf.string),"label": tf.io.FixedLenSequenceFeature([],dtype=tf.int64),}
context_features = {
    "frames_count": tf.io.FixedLenFeature([],"num_tokens":  tf.io.FixedLenFeature([],}
dataset = tf.data.TFRecordDataset("train-0.tfrecord")
dataset = dataset.map(_parse_function)
dataset = dataset.padded_batch(3)

问题是,我无法正确编写_parse_function,所以我可以遍历填充tf.int8张量序列的填充批次序列,这些序列代表一个视频的帧以及批次的相应的标签。我也想避免使用VarLenFeature,因为稀疏张量在CTC loss上的GPU上无法很好地发挥作用。

这是我尝试过的:

def _parse_function(example_proto):
    context,sequence,_ = tf.io.parse_sequence_example(example_proto,context_features=context_features,sequence_features=sequence_features)
    image = tf.io.decode_raw(sequence["frames"],tf.int8)
    label = sequence["label"]

    return image,label

抛出InvalidArgumentError: DecodeRaw requires input strings to all be the same size,but element 1 has size 2444 != 2456 [[{{node DecodeRaw}}]]

parse_sequence_example更改为parse_single_sequence_example并没有帮助,并引发相同的错误

所以问题是,我应该如何修改_parse_function以使其返回tf.int8帧序列的批次,形状为BxTxWxH,其中B是批次大小T是序列长度吗?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。