微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

如何在tensorflow 2中使用未知参数的ctc损失?

如何解决如何在tensorflow 2中使用未知参数的ctc损失?

我在TFRecord中有一个数据集,该数据集分为多个文件

我的记录如下:

sequence_features = {
    'frames': tf.io.FixedLenSequenceFeature([],dtype=tf.string),"label": tf.io.FixedLenSequenceFeature([],dtype=tf.int64),}

context_features = {
    "frames_count": tf.io.FixedLenFeature([],"num_tokens":  tf.io.FixedLenFeature([],}

frameslabel序列的长度不同。 帧的形状为BxTxHxWxC张量,其中B是批处理大小,T是批处理中最大的帧数,H是图像高度(112),{ {1}}是图片宽度(112),W是频道(1)

C的形状为label,其中BxS是批次中最长的标签

我正在S上建立我的模型,并像这样编译它:

tf.keras.Sequential()

问题model.compile( loss=tf.nn.ctc_loss(),optimizer=tf.keras.Adam(),) 需要4个参数-tf.nn.ctc_loss()labelslogitslabel_length。但是在模型编译时我还不知道,因为它在批次之间以及每个时期都不同。

在整个数据集中将帧计数或标签填充或裁剪为相同的长度并不是真正的选择,因为最长的序列可能比最短的序列长20倍。

我正在这样读取数据集:

logit_length

在这种情况下我该怎么办?或者如何更改代码以利用ctc损失?

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。