微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

使用迭代器提取张量值

如何解决使用迭代器提取张量值


我正在尝试为数据集的每个图像提取标签值。我的想法是拆分训练和测试:

data,info= tfds.load("cats_vs_dogs",as_supervised = True,split = 'train',with_info=True)

# Split into training and test parts
test_dataset = data.take(5500)
train_dataset = data.skip(5500)

然后我想应用预处理来根据训练中的标签进行区分:

def preprocess_start(image,label):
  #cast the image values from integers to floating and then divide by 255
  image = tf.image.resize(image,[100,100])
  image = tf.cast(image,tf.float32)/ 255.0
  if label == 0:   #this check is not correct
    
    code ...
  else:
    code ...
  return image,label

train_test= train_test.map(preprocess_start) 

但是主要的问题是标签是:

Tensor("args_1:0",shape=(),dtype=int64)

如何提取整数值?

解决方法

这是基于AutoGraph Transformations

# Split into training and test parts
test_dataset = data.take(5500)
train_dataset = data.skip(5500)

@tf.function
def preprocess_start(image,label):
    # cast the image values from integers to floating and then divide by 255
    image = tf.image.resize(image,[100,100])
    image = tf.cast(image,tf.float32) / 255.0
    if tf.equal(label,0 ):  # this check is not correct
        tf.print("Zero",label)
    else:
        tf.print("Not Zero",label)
    return image,label


train_test = train_dataset.map(preprocess_start)
for e in train_test:
    tf.print(e)

所以对于更简单的功能,这应该可以工作。如果您需要更复杂的逻辑,您必须阅读 rules

Zero 0
Zero 0
Not Zero 1
Zero 0
Zero 0
Zero 0
Not Zero 1

因此,例如,此检查应按预期工作。

l = tf.constant(0)
if tf.equal(l,0 ):  # this check is not correct
    tf.print("Zero",l)
else:
    tf.print("Not Zero",l)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。