Why do I get different accuracy results for the same network configuration when I use the Scikit-Learn Keras model functions?

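I used Keras's scikit-learn classifier API, i.e. "tf.keras.wrappers.scikit_learn.KerasClassifier", to build a DNN, and my mean CV score was 53%. When I perform the same classification without the Keras wrapper, using the same architecture and hyperparameters, my mean CV score is 24.23%. I followed the code from Jason Brownlee's book "Deep Learning With Python". Without the wrapper, my code is: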

from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
cvscores = []

for train, test in kfold.split(X, y):
    model = Sequential()
    model.add(Dense(128, input_dim=76636, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model
    model.fit(X[train], y[train], epochs=50, batch_size=512, verbose=0)
    # Evaluate the model
    scores = model.evaluate(X[test], y[test], verbose=0)
    # print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
    cvscores.append(scores[1] * 100)

print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

I get the following output: 24.23% (+/- 2.35%)

When I use the Keras wrapper, my code is:

from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import numpy

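# Function to create model, required for KerasClassifier
def create_model():
    # create model (same architecture and compile settings as the loop version)
    model = Sequential()
    model.add(Dense(128, input_dim=76636, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

# note: nb_epoch is the Keras 1 name of the epochs argument, and batch_size is not set here
model = KerasClassifier(build_fn=create_model, nb_epoch=50, verbose=0)

# evaluate using 5-fold cross validation
# note: unlike the loop version, shuffle=True is not set, so the folds differ
kfold = StratifiedKFold(n_splits=5, random_state=seed)
results = cross_val_score(model, X, y, cv=kfold)
print(results.mean())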

The output is: 0.5315796375274658

Solution

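I ran your code on the pima-indians-diabetes.csv dataset but could not reproduce the problem you are facing. The results differ slightly, but the difference can be explained by the fold-to-fold spread reported by numpy.std(cvscores).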
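As a rough check of that claim (not a significance test), the gap between the two mean scores reported in the runs below can be expressed in units of the loop version's standard deviation. The helper names are illustrative. `statistics.pstdev` is used because it divides by n, as `numpy.std` does with its default `ddof=0`; note that the wrapper prints a fraction, so it is scaled to percent first.

```python
import statistics


def mean_and_std(scores):
    """Return (mean, population std) of per-fold scores, like numpy.mean/numpy.std."""
    return statistics.fmean(scores), statistics.pstdev(scores)


def gap_in_stds(mean_a, std_a, mean_b):
    """Distance between two mean scores, measured in units of std_a."""
    return abs(mean_a - mean_b) / std_a


# Numbers reported below: loop version 48.70% (+/- 5.46%),
# wrapper version 0.4320685863494873 (a fraction, so convert to percent).
loop_mean, loop_std = 48.70, 5.46
wrapper_mean = 0.4320685863494873 * 100

print(round(gap_in_stds(loop_mean, loop_std, wrapper_mean), 2))
```

With the reported numbers the gap comes out at about 1.01, i.e. the two runs are roughly one fold-to-fold standard deviation apart.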

Below are the run details:

Without the wrapper function:

%tensorflow_version 2.x
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy

# load pima indians dataset
dataset = numpy.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")

# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
y = dataset[:, 8]

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
cvscores = []

for train, test in kfold.split(X, y):
    model = Sequential()
    model.add(Dense(128, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model
    model.fit(X[train], y[train], epochs=50, batch_size=512, verbose=0)
    # Evaluate the model
    scores = model.evaluate(X[test], y[test], verbose=0)
    # print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
    cvscores.append(scores[1] * 100)

print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))

Output:

WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_train_function.<locals>.train_function at 0x7fc2c1468ae8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,(2) passing tensors with different shapes,(3) passing Python objects instead of tensors. For (1),please define your @tf.function outside of the loop. For (2),@tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3),please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:7 out of the last 30 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c4579d90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:7 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c1604d08> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
48.70% (+/- 5.46%)
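In this Pima example `y = dataset[:,8]` is a 1-D array of 0/1 labels, while the model ends in `Dense(2, activation='softmax')`. A two-unit softmax is normally trained against one-hot targets with `categorical_crossentropy` (or against integer labels with `sparse_categorical_crossentropy`). Whether this mismatch affects the accuracies above is not established here; the sketch below only shows the encoding, assuming the labels are the integers 0 and 1 (the `one_hot` helper and the three sample labels are hypothetical).

```python
import numpy as np


def one_hot(labels, num_classes=2):
    """Turn integer class labels (e.g. 0.0/1.0 floats from loadtxt) into one-hot rows."""
    labels = np.asarray(labels).astype(int)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes)[labels]


y_example = np.array([1.0, 0.0, 1.0])  # hypothetical labels, same dtype loadtxt produces
print(one_hot(y_example))
```

With `y_onehot = one_hot(y)`, the model would be compiled with `loss='categorical_crossentropy'` and fitted on `y_onehot[train]`, while `kfold.split(X, y)` keeps receiving the 1-D `y`, since stratification needs plain class labels.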

With the wrapper function:

from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import numpy

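# load pima indians dataset
dataset = numpy.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")

# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
y = dataset[:, 8]

# Function to create model, required for KerasClassifier
def create_model():
    # create model (same architecture as the loop version above)
    model = Sequential()
    model.add(Dense(128, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)

model = KerasClassifier(build_fn=create_model, nb_epoch=50, verbose=0)

# evaluate using 5-fold cross validation
kfold = StratifiedKFold(n_splits=5, random_state=seed)
results = cross_val_score(model, X, y, cv=kfold)
print(results.mean())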

Output:

WARNING:tensorflow:5 out of the last 13 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c4b79ea0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:5 out of the last 107 calls to <function Model.make_train_function.<locals>.train_function at 0x7fc2d31262f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 14 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c4b79a60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:6 out of the last 109 calls to <function Model.make_train_function.<locals>.train_function at 0x7fc2c4cc4268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
WARNING:tensorflow:7 out of the last 15 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c15a4268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
0.4320685863494873
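One caveat about both wrapper snippets: they pass `nb_epoch=50`, the Keras 1 name of the `epochs` argument, and no `batch_size`. As I understand the TF 2.x `KerasClassifier`, it forwards to `fit` only the keyword arguments named in `fit`'s signature and deliberately does not raise for `nb_epoch`, so the wrapper runs may have trained with the defaults (`epochs=1`, `batch_size=32`) rather than 50 epochs at batch size 512. The sketch below imitates that kind of signature-based filtering with `inspect.signature`; `fake_fit` and `filter_params` are hypothetical stand-ins, not the wrapper's actual code.

```python
import inspect


def fake_fit(x=None, y=None, batch_size=None, epochs=1, verbose=1):
    """Hypothetical stand-in for Model.fit; batch_size=None falls back to 32."""
    return {"epochs": epochs, "batch_size": batch_size or 32}


def filter_params(fn, params):
    """Keep only the keyword arguments that fn explicitly declares."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


# What the question passes: nb_epoch is silently dropped by the filter.
print(fake_fit(**filter_params(fake_fit, {"nb_epoch": 50, "verbose": 0})))

# Using the current argument names, both settings reach fit.
print(fake_fit(**filter_params(fake_fit, {"epochs": 50, "batch_size": 512, "verbose": 0})))
```

In the real wrapper the corresponding change would be `KerasClassifier(build_fn=create_model, epochs=50, batch_size=512, verbose=0)`, together with `StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)` so that both versions train with the same settings on the same folds.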