如何解决使用函数在tf数据集中的Tensorflow分割字符串张量产生错误
我是tensorflow的新手,学习使用tensorflow数据集。在尝试像下面这样将tf.TextLineDataset
的每个字符串张量分成两个张量时。
下面是一个最小的可复制示例。
期望
import tensorflow as tf
dataset = tf.data.TextLineDataset(<input_file>)
for x in dataset.take(2):
print(x)
# tf.Tensor(b'sample text\tsample label',shape=(),dtype=string)
for x in dataset.take(2):
a,b = tf.strings.split(x,sep='\t')
print(a)
print(b)
# tf.Tensor(b'sample text',dtype=string)
# tf.Tensor(b'sample label',dtype=string)
使用功能
def labeler(example):
'''
Splits each line into text and label sequence
'''
text_seq,label_seq = tf.strings.split(example,sep='\t')
return text_seq,label_seq
labeled_dataset = dataset.map(lambda x: labeler(x))
上面的代码返回以下错误
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
任何人都可以帮助解决上述错误。
更新
将函数定义更改为以下内容有助于解决错误。
def labeler(example):
'''
Splits each line into text and label sequence
'''
return tf.strings.split(example,sep='\t')
直觉上,这两种方法似乎都在做相同的事情。现在我对张量流如何不同地处理这两种方法感到困惑。
任何见解都会很有帮助。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。