如何解决TypeError:传递给参数'input'的值的DataType bool不在允许值列表中:float32,float64,int32,uint8,int16,int8
我有一个包含5个标签的数据集
def get_label(file_path):
# convert the path to a list of path components
parts = tf.strings.split(file_path,os.path.sep)
class_names = ['daisy' 'dandelion' 'roses' 'sunflowers' 'tulips']
# The second to last is the class-directory
one_hot = parts[-2] == class_names
# Integer encode the label
return tf.argmax(one_hot)
def decode_img(img):
# convert the compressed string to a 3D uint8 tensor
img = tf.image.decode_jpeg(img,channels=3)
# resize the image to the desired size
return tf.image.resize(img,[img_height,img_width])
def process_path(file_path):
label = get_label(file_path)
# load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
return img,label
train_ds = train_ds.map(process_path,num_parallel_calls=AUTOTUNE)
如果我使用具有2个标签的其他数据集来更改此代码,则class_names = ['dog','cat']
会发现此错误
TypeError: Value passed to parameter 'input' has DataType bool not in list of allowed values: float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,float16,uint32,uint64
那么我如何更新 def get_label(file_path)
解决方法
我的猜测是tf.argmax需要这些数据类型之一(我现在无法测试)
float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,float16,uint32,uint64
所以您要做的就是转换输出
one_hot = parts[-2] == class_names
对于int,“ ==“的计算结果为True / False,这可能是不允许的。
,我遇到了同样的问题。遵循最后一篇文章的想法:
from django.contrib.auth import logout
from django.http import HttpResponse
def logout_view(request):
logout(request)
return HttpResponse('OK')
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。