如何解决使用 tensorflow keras 预测 5 个不同类别的标签
我有以下问题,我有一个包含 3dprinter 数据的数据集,并且想要使用 tensorflow nn 预测表示错误的标签。 但是,该标签从 0 到 5 - 我怎么能做到这一点?我需要五种不同的输出吗?因为据我了解分类,它只分配标签。
找不到任何关于此的确切信息,也许是因为我不知道如何搜索它 - 在整个主题中都很新。
数据要么是单热编码的,要么是浮动的,我正在尝试使用 keras 调谐器来查找网络的超参数 - 我目前是这样的:
def build_model_hp(self,hp,model_type):
if model_type == 'standard':
shape = (59,)
elif model_type == 'expert':
shape = (73,)
else:
shape = (60,)
inputs = tf.keras.Input(shape=shape)
x = inputs
for i in range(hp.Int('hidden_blocks',3,10,default=3)):
x = tf.keras.layers.Dense(hp.Int('hidden_size_'+str(i),16,256,step=16,default=16),activation='relu')(x)
x = tf.keras.layers.Dropout(hp.Float('dropout',0.5,step=0.1,default=0.5))(x)
outputs = tf.keras.layers.Dense(1,activation='sigmoid')(x)
model = tf.keras.Model(inputs,outputs)
if (hp.Choice('optimizer',['adam','sgd'])) == 'adam':
opt = tf.keras.optimizers.Adam(
hp.Float('learning_rate',1e-4,1e-2,sampling='log'))
else:
opt = tf.keras.optimizers.SGD(
hp.Float('learning_rate',1e-2),nesterov=True)
model.compile(
optimizer=opt,loss='binary_crossentropy',metrics=['accuracy'])
return model
解决方法
如果您有 6 个标签为 0-5 的类,则将输出层从
outputs = tf.keras.layers.Dense(1,activation='sigmoid')(x)
输出 = tf.keras.layers.Dense(6,activation='softmax')(x)
change your model compile code from
model.compile(
optimizer=opt,loss='binary_crossentropy',metrics=['accuracy'])
如果您的标签是单热编码,则如下所示
model.compile(
optimizer=opt,loss='categorical_crossentropy',metrics=['accuracy'])
如果你的标签是整数,那么使用
model.compile(
optimizer=opt,loss='sparse_categorical_crossentropy',metrics=['accuracy'])
在您训练模型后(假设您对标签进行热编码并使用 loss='categorical_crossentropy),然后对您的测试集进行预测
from sklearn.metrics import confusion_matrix,classification_report
classes=test_gen.class_indices.keys()
labels=test_gen.labels
y_pred=[]
y_true=[]
preds=model.predict(test_gen)
for i,p in enumerate(preds)
y_pred=np.argmax(p)
y_true=labels[i] # assumes you have a list of labels for each test file
ypred=np.array(y_pred)
ytrue=np.array(y_true)
clr = classification_report(y_true,y_pred,target_names=classes) # assumes classes is a list of your classes
print("Classification Report:\n----------------------\n",clr)
我假设您有一个测试生成器来生成批量测试数据
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。