Keras交叉验证:输出广播错误

如何解决Keras交叉验证:输出广播错误

我在应用k倍层折时遇到问题。在下面,您可以找到运行良好的我的体系结构。它由两个信号输入组成,这些信号输入随后在特征提取之后与inputB,inputC和inputD连接:

def small_model2(): 
    
    ampl_signal = Input(shape=(X_train.shape[1:]))
    phase_signal = Input(shape=(X_train_phase.shape[1:]))
    inputA= Input(shape=(1,))
    inputB= Input(shape=(1,))
    inputC= Input(shape=(1,))
    
    concat_signal = concatenate([ampl_signal,phase_signal])
    #x = InputLayer(input_shape=(None,X_train.shape[1:][0],1))(inputA)
    x = Conv1D(64,5,activation='relu',kernel_initializer='glorot_normal')(concat_signal) #,input_shape=(None,3750,n_features)
    x = Conv1D(64,kernel_initializer='glorot_normal')(x)
   # x = Dropout(0.1)(x)
    x = MaxPooling1D(5)(x) 
   # x = Conv1D(128,kernel_initializer='glorot_normal')(x) #,n_features)
   # x = Conv1D(128,kernel_initializer='glorot_normal')(x)
   # x = Dropout(0.1)(x)
   # x = MaxPooling1D(5)(x)   
    x = Conv1D(64,activation='elu',kernel_initializer='glorot_normal')(x)
    x = Conv1D(64,kernel_initializer='glorot_normal')(x)
    #x = Dropout(0.1)(x)
    # x = Dropout(0.2)(x)
    #x = Flatten()(x)
    x = GlobalAveragePooling1D()(x)
    concatenated_features = concatenate([x,inputB,inputC,inputD])#inputD
    x = Dense(64,activation='relu')(concatenated_features)
    #
# Check for the position of the dropout
  #  x = Dropout(0.2)(x)
    x = Dense(n_outputs,activation='sigmoid')(x)

    model = Model(inputs=[ampl_signal,phase_signal,inputD],outputs=x)
    
    #optim = SGD(lr=lr,clipnorm=1.)
    optim = Adam(lr=lr)
    model.compile(loss='binary_crossentropy',optimizer=optim,metrics=['accuracy'])
    #print(model.summary())
    return model

然后,我尝试应用k倍交叉验证:

from sklearn.model_selection import KFold

acc_per_fold = []
loss_per_fold = []


num_folds = 4
kfold = KFold(n_splits=num_folds,shuffle=True)

fold_no = 1


#inputs = np.c_[X_train,X_train_phase,train_moisture,train_temp,train_weight]
#targets = y_train

inputs = np.concatenate(([X_train,train_weight],[X_test,X_val_phase,test_moisture,test_temp,test_weight]),axis=0)
targets = np.concatenate((y_train,y_test),axis=0)

for train,test in kfold.split(inputs,targets):
    model=small_model()
    #history=model.fit([X_train,y_train,#              validation_data = ([X_test,test_weight],#              #validation_data=([X_val,val_weight,val_moisture,val_temp],y_train[val]),#              epochs=100,batch_size=150)
    
    history = model.fit(inputs[train],targets[train],batch_size=batch_size,epochs=no_epochs,verbose=verbosity) 
    scores = model.evaluate([X_val,y_train[val],verbose=0)
#    
    print("%s: %.2f%%" % (model.metrics_names[0],scores[0]))
    print("%s: %.2f%%" % (model.metrics_names[1],scores[1]))
#    
  # Generate generalization metrics
    scores = model.evaluate(inputs[test],targets[test],verbose=0)
    print(f'score for fold {fold_no}: {model.metrics_names[0]} of {scores[0]}; {model.metrics_names[1]} of {scores[1]*100}%')
    acc_per_fold.append(scores[1] * 100)
    loss_per_fold.append(scores[0])    
    
    loss_history =np.array(history.history["loss"])
    accuracy_history= np.array(history.history["acc"])
    val_loss_history = np.array(history.history["val_loss"])
    val_accuracy_history=np.array(history.history["val_acc"])#

    np.savetxt("outputs/accuracy_loss_history_64000-90000_down_5_CNN_lr_"+str(lr)+"_fold_"+str(i)+".txt",np.c_[loss_history,accuracy_history,val_loss_history,val_accuracy_history],delimiter="\t")
    i+=1
    
    model.evaluate([X_test.reshape(len(X_test),5200,1),test_temp],y_test,verbose=0)
    
    fold_no = fold_no +1

但是我仍然遇到广播错误

ValueError:无法将形状为(1929,6249,1)的输入数组广播为形状(1929)

我知道我的输入似乎没有正确的形状。但是我不明白为什么。我输入与训练神经网络时相同的输入。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?