在 3 个合并的深度神经网络模型中交叉验证而不是训练和测试

如何解决在 3 个合并的深度神经网络模型中交叉验证而不是训练和测试

如何在这个深度神经网络代码中使用交叉验证方法而不是 train_test split 实际上,我正在合并3个深度神经网络模型......首先我合并了3个深度神经网络模型,然后进行分类。 .由于准确性问题,我想使用交叉验证方法而不是训练测试拆分,请指导我如何在此深度神经网络代码中使用交叉验证方法而不是 train_ 测试拆分...

这是我的代码,如果有人可以将此代码训练测试更改为交叉验证,那么我将非常感激

from keras.layers import Dense
from sklearn.model_selection import train_test_split
from keras.models import Model
from keras.layers import Input 
import tensorflow as tf
import os,numpy


# random seed for reproducibility

numpy.random.seed(101)
tf.random.set_seed(101)
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

# loading load pima indians diabetes dataset,past 5 years of medical history
dataset = numpy.loadtxt('https://raw.githubusercontent.com/jbrownlee/Datasets/master/pima-indians-diabetes.data.csv',delimiter=",")
# split into input (X) and output (Y) variables,splitting csv data
X = dataset[:,0:8]
Y = dataset[:,8]
x_train,x_validation,y_train,y_validation = train_test_split(
    X,Y,test_size=0.20,random_state=5)

# Define Model A 
input_layer = Input(shape=(8,))
A2 = Dense(10,activation='relu')(input_layer)
A3 = Dense(50,activation='relu')(A2)
A4 = Dense(50,activation='relu')(A3)
A5 = Dense(50,activation='relu')(A4)
A6 = Dense(50,activation='relu')(A5)
A7 = Dense(50,activation='relu')(A6)
A8 = Dense(10,activation='relu')(A7)
A9 = Dense(5,activation='relu')(A8)
model_a = Model(inputs=input_layer,outputs=A9,name="ModelA")

# Define Model B 
input_layer = Input(shape=(8,))
B2 = Dense(13,activation='relu')(input_layer)
B3 = Dense(12,activation='relu')(B2)
B4 = Dense(16,activation='relu')(B3)
B5 = Dense(16,activation='relu')(B4)
B6 = Dense(14,activation='relu')(B5)
B7 = Dense(5,activation='relu')(B6)
model_b = Model(inputs=input_layer,outputs=B7,name="ModelB")

# Define Model C
input_layer = Input(shape=(8,))
C2 = Dense(33,activation='relu')(input_layer)
C3 = Dense(23,activation='relu')(C2)
C4 = Dense(17,activation='relu')(C3)
C5 = Dense(9,activation='relu')(C4)
C6 = Dense(17,activation='relu')(C5)
C7 = Dense(13,activation='relu')(C6)
C8 = Dense(9,activation='relu')(C7)
C9 = Dense(9,activation='relu')(C8)
C10 = Dense(5,activation='relu')(C9)
model_c = Model(inputs=input_layer,outputs=C10,name="ModelC")
all_three_models = [model_a,model_b,model_c]
all_three_models_input = Input(shape=all_three_models[0].input_shape[1:])


models_output = [model(all_three_models_input) for model in all_three_models]
Concat           = tf.keras.layers.concatenate(models_output,name="Concatenate")
final_out     = Dense(1,activation='sigmoid')(Concat)

final_model   = Model(inputs=all_three_models_input,outputs=final_out,name='Ensemble')

# call the function to fit to the data (training the network)
final_model.compile(loss="binary_crossentropy",optimizer="adam",metrics=['accuracy'])

model_save = tf.keras.callbacks.ModelCheckpoint(
                'merge_best.h5',monitor="val_accuracy",verbose=0,save_best_only=True,save_weights_only=True,mode="max",save_freq="epoch"
            )

# call the function to fit to the data (training the network)
final_model.fit(x_train,epochs=1000,batch_size=256,callbacks=[model_save],validation_data=(x_validation,y_validation))


# evaluate the model
final_model.load_weights('merge_best.h5')
scores = final_model.evaluate(x_validation,y_validation)
print("\n%s: %.2f%%" % (final_model.metrics_names[1],scores[1] * 100))
final_model.save('diabetes_risk_nn.h5')

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?