微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

多输入/多输出:使用 KerasClassifier 和 GridSearchCV 时输出维度错误

如何解决多输入/多输出:使用 KerasClassifier 和 GridSearchCV 时输出维度错误

我已经使用 keras 和 tensorflow 构建了一个多输入(100 个特征)多输出(100 个预测)ANN 模型。我已经能够训练我的模型并在测试集使用以下代码

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K 
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout

def my_loss_fn(y_true,y_pred) : 
   d = K.sum(K.abs(y_true),axis = -1)
   n = K.sum((K.tanh(100000*y_true*y_pred)/2 + 0.5)*K.abs(y_true),axis = -1)
   return 1 - n/d

def my_metric_fn(y_true,y_pred) : 
   d = K.sum(K.abs(y_true))
   n = K.sum((K.tanh(100000*y_true*y_pred)/2 + 0.5)*K.abs(y_true))
   return n/d

def accuracy(y_true,y_pred) : 
   #print(y_true.shape,y_true)
   #print(y_pred.shape,y_true)
   acc = np.zeros([1,len(y_true)])
   for day in range(len(y_pred)) :
       d = 0
       n = 0
       for i in range(len(y_pred[0])) :
           d = d + abs(y_true[day,i])
           if np.sign(y_pred[day,i])*np.sign(y_true[day,i]) > 0 : 
               n = n + abs(y_true[day,i])
           else : 
               n = n + 0
       acc[0,day] = n/d
   return np.mean(acc,axis = -1)[0]

#Model
classifier = Sequential()
classifier.add(Dense(units = 50,input_shape = (100,),activation = "tanh"))
classifier.add(Dropout(0.2))
classifier.add(Dense(units=100,activation = 'tanh'))
classifier.compile(optimizer = 'rmsprop',loss = my_loss_fn,metrics = ['accuracy',my_metric_fn]) 

#Training
callback = tf.keras.callbacks.EarlyStopping(monitor = 'val_loss',min_delta = 0.0001,patience = 20,verbose = 0,mode = 'min')
nb_epochs = 250
history = classifier.fit(X_train,y_train,epochs = nb_epochs,batch_size = 31,callbacks = [callback],verbose = True,validation_split = 0.,validation_data = (X_test,y_test),use_multiprocessing = True)

#Prediction
y_pred_train = classifier.predict(X_train)
y_pred_test = classifier.predict(X_test)
acc_test = accuracy(y_test,y_pred_test)
acc_train = accuracy(y_train,y_pred_train)

我试图通过调整超参数来提高模型的性能,因此我使用了 KerasClassifier()gridsearchcv()。以下代码说明了我的网格搜索方法

from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import gridsearchcv
from sklearn.metrics import make_scorer
from tensorflow import autograph

#Building a function to create the classifier
def build_classifier(nb_layers,nb_nodes,optimizer,dropout,activation_fn):
    classifier=Sequential()
    classifier.add(Dense(units = nb_nodes,activation = activation_fn))
    for i in range(nb_layers-1) : 
        classifier.add(Dense(units = nb_nodes,activation = activation_fn,kernel_initializer = "uniform"))
        classifier.add(Dropout(dropout))
    classifier.add(Dense(units = 100,activation = 'tanh'))
    classifier.compile(optimizer=optimizer,loss = tf.autograph.experimental.do_not_convert(my_loss_fn),metrics= ['accuracy',tf.autograph.experimental.do_not_convert(my_metric_fn)])
    return classifier

#Creating a scorer to Feed to the gridsearchcv()
my_scorer = make_scorer(accuracy,greater_is_better = True)
classifier=KerasClassifier(build_fn=build_classifier)
parameters={'batch_size':[13,31],'epochs':[100,150],'optimizer':['adam','rmsprop'],'dropout' : [0.2,0.1],'nb_layers' : [2,3],'nb_nodes' : [45,50,110,115],'activation_fn' : ['relu','tanh']} 
grid_search=gridsearchcv(estimator=classifier,scoring = my_scorer,param_grid=parameters,cv=5,verbose = 1) 
grid_search=grid_search.fit(X_train_,y_train_raw)

当我拟合我的 gridsearchcv() 对象时,我在第一个超参数组合结束时收到以下错误(计算评分时):

TypeError: object of type 'numpy.int32' has no len()

我通过在 accuracy() 函数添加打印命令进行了调查

#print(y_true.shape,y_true)
#print(y_pred.shape,y_pred)

打印形状和数组 y_truey_pred 作为我的 accuracy() 函数的输入,用作 scoring 对象中的 gridsearchcv()

我发现 y_true.shape == (555,100)y_pred.shape == (555,)。值 555 对应于第五个验证集的行数,因为 cv = 5

然而,我不明白为什么即使分类器最后一层的节点数是gridsearch(100,)的预测也不是多输出预测。

解决方法

这是一个回归问题,所以我改用 KerasRegressor() 并解决了这个问题。我猜对于一个多输出分类问题,KerasClassifier() 期望输出是一个二维热编码数组。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。