如何解决使用迭代器提取张量值
我正在尝试为数据集的每个图像提取标签值。我的想法是拆分训练和测试:
data,info= tfds.load("cats_vs_dogs",as_supervised = True,split = 'train',with_info=True)
# Split into training and test parts
test_dataset = data.take(5500)
train_dataset = data.skip(5500)
然后我想应用预处理来根据训练中的标签进行区分:
def preprocess_start(image,label):
#cast the image values from integers to floating and then divide by 255
image = tf.image.resize(image,[100,100])
image = tf.cast(image,tf.float32)/ 255.0
if label == 0: #this check is not correct
code ...
else:
code ...
return image,label
train_test= train_test.map(preprocess_start)
但是主要的问题是标签是:
Tensor("args_1:0",shape=(),dtype=int64)
如何提取整数值?
解决方法
# Split into training and test parts
test_dataset = data.take(5500)
train_dataset = data.skip(5500)
@tf.function
def preprocess_start(image,label):
# cast the image values from integers to floating and then divide by 255
image = tf.image.resize(image,[100,100])
image = tf.cast(image,tf.float32) / 255.0
if tf.equal(label,0 ): # this check is not correct
tf.print("Zero",label)
else:
tf.print("Not Zero",label)
return image,label
train_test = train_dataset.map(preprocess_start)
for e in train_test:
tf.print(e)
所以对于更简单的功能,这应该可以工作。如果您需要更复杂的逻辑,您必须阅读 rules
Zero 0
Zero 0
Not Zero 1
Zero 0
Zero 0
Zero 0
Not Zero 1
因此,例如,此检查应按预期工作。
l = tf.constant(0)
if tf.equal(l,0 ): # this check is not correct
tf.print("Zero",l)
else:
tf.print("Not Zero",l)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。