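Why does the same network configuration give different accuracy with and without the Keras scikit-learn wrapper (KerasClassifier)?
I built a DNN with Keras's scikit-learn classifier API, i.e. "tf.keras.wrappers.scikit_learn.KerasClassifier", and got a mean cross-validation score of 53%. When I run the same classification without the Keras wrapper, my mean CV score is 24.23%, even though I use the same architecture and hyperparameters. I followed the code from Jason Brownlee's book "Deep Learning With Python". Without the wrapper function, my code is: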
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# X, y: feature matrix (76636 features) and labels, loaded earlier
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, y):
    model = Sequential()
    model.add(Dense(128, input_dim=76636, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model
    model.fit(X[train], y[train], epochs=50, batch_size=512, verbose=0)
    # Evaluate the model
    scores = model.evaluate(X[test], y[test], verbose=0)
    # print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
    cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))
I get the following output: 24.23% (+/- 2.35%)
With the KerasClassifier wrapper, my code is:
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import numpy
# Function to create model, required for KerasClassifier
def create_model():
    # create model (same layers as in the loop above)
    model = Sequential()
    model.add(Dense(128, input_dim=76636, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
model = KerasClassifier(build_fn=create_model, nb_epoch=50, verbose=0)
# evaluate using 5-fold cross validation
kfold = StratifiedKFold(n_splits=5, random_state=seed)
results = cross_val_score(model, X, y, cv=kfold)
print(results.mean())
The output is: 0.5315796375274658
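One concrete difference between the two scripts above is that the loop passes epochs=50 and batch_size=512 to fit(), while the wrapper receives nb_epoch=50 and no batch size. nb_epoch is the Keras 1 spelling; Keras 2 uses epochs. The KerasClassifier documentation describes sk_params as being forwarded to fit() (and predict(), etc.) only when they are legal arguments of those methods, so, depending on the Keras version, nb_epoch may be silently dropped and training would fall back to fit()'s defaults (1 epoch, batch size 32). The sketch below is a simplified, hypothetical re-creation of that filtering using inspect.signature; keras2_style_fit and filter_params are illustrative names, not Keras code.

```python
import inspect


def keras2_style_fit(x=None, y=None, batch_size=None, epochs=1, verbose=1):
    """Hypothetical stand-in for a Keras 2 fit(): it has `epochs`, not `nb_epoch`."""
    return {"batch_size": batch_size, "epochs": epochs, "verbose": verbose}


def filter_params(fn, params):
    """Keep only the entries of `params` that name a parameter of `fn`."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


# The keyword arguments the question passes to KerasClassifier (besides build_fn).
sk_params = {"nb_epoch": 50, "verbose": 0}
fit_args = filter_params(keras2_style_fit, sk_params)

print(fit_args)                                  # {'verbose': 0}
print(keras2_style_fit(None, None, **fit_args))  # {'batch_size': None, 'epochs': 1, 'verbose': 0}
```

If this is what happens in your environment, passing the loop's values explicitly, e.g. KerasClassifier(build_fn=create_model, epochs=50, batch_size=512, verbose=0), and using StratifiedKFold(n_splits=5, shuffle=True, random_state=seed) in both scripts so they share the same folds, should make the two runs directly comparable.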
Solution
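I ran your code on the pima-indians-diabetes.csv dataset but could not reproduce the problem you describe. The results differ slightly, but the difference is within the spread of the per-fold scores reported by numpy.std(cvscores).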
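Some run-to-run difference is expected even with the seed fixed: numpy.random.seed(seed) only resets NumPy's global generator, while Keras typically draws its initial weights from TensorFlow's own random generator (seeded separately in TensorFlow 2.x with tf.random.set_seed). The sketch below illustrates the general point in plain Python, using the standard-library random module as a stand-in for that second generator; it is an analogy, not TensorFlow code, and draws_after_seeding is a hypothetical helper.

```python
import random

import numpy


def draws_after_seeding(seed):
    """Seed only NumPy's global generator, then draw from it and from an
    unseeded, independent generator (Python's `random`, standing in for
    TensorFlow's own RNG)."""
    numpy.random.seed(seed)
    numpy_draws = numpy.random.rand(3).tolist()
    other_draws = [random.random() for _ in range(3)]
    return numpy_draws, other_draws


np_a, other_a = draws_after_seeding(7)
np_b, other_b = draws_after_seeding(7)

print(np_a == np_b)        # True: the NumPy stream restarts from the same seed
print(other_a == other_b)  # False in practice: the other generator was never reset
```

In a TensorFlow 2.x script the analogous step would be to call tf.random.set_seed(seed) alongside numpy.random.seed(seed); even then, some GPU operations may remain nondeterministic.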
Here are the details of the runs.
Without the wrapper function:
%tensorflow_version 2.x
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import StratifiedKFold
import numpy
# load pima indians dataset
dataset = numpy.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
y = dataset[:, 8]
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
cvscores = []
for train, test in kfold.split(X, y):
    model = Sequential()
    model.add(Dense(128, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    # Fit the model
    model.fit(X[train], y[train], epochs=50, batch_size=512, verbose=0)
    # Evaluate the model
    scores = model.evaluate(X[test], y[test], verbose=0)
    # print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
    cvscores.append(scores[1] * 100)
print("%.2f%% (+/- %.2f%%)" % (numpy.mean(cvscores), numpy.std(cvscores)))
Output:
WARNING:tensorflow:6 out of the last 11 calls to <function Model.make_train_function.<locals>.train_function at 0x7fc2c1468ae8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,(2) passing tensors with different shapes,(3) passing Python objects instead of tensors. For (1),please define your @tf.function outside of the loop. For (2),@tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3),please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:7 out of the last 30 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c4579d90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:7 out of the last 11 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c1604d08> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
48.70% (+/- 5.46%)
With the wrapper function:
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import numpy
# load pima indians dataset
dataset = numpy.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:, 0:8]
y = dataset[:, 8]
# Function to create model, required for KerasClassifier
def create_model():
    # create model (same layers as in the loop above)
    model = Sequential()
    model.add(Dense(128, input_dim=8, kernel_initializer='uniform', activation='relu'))
    model.add(Dense(64, activation='relu', kernel_initializer='uniform'))
    model.add(Dense(2, activation='softmax'))
    # Compile model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
model = KerasClassifier(build_fn=create_model, nb_epoch=50, verbose=0)
# evaluate using 5-fold cross validation
kfold = StratifiedKFold(n_splits=5, random_state=seed)
results = cross_val_score(model, X, y, cv=kfold)
print(results.mean())
Output:
WARNING:tensorflow:5 out of the last 13 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c4b79ea0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:5 out of the last 107 calls to <function Model.make_train_function.<locals>.train_function at 0x7fc2d31262f0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 14 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c4b79a60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 109 calls to <function Model.make_train_function.<locals>.train_function at 0x7fc2c4cc4268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:7 out of the last 15 calls to <function Model.make_test_function.<locals>.test_function at 0x7fc2c15a4268> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop,please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
0.4320685863494873
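To put the printed numbers side by side, the sketch below expresses the gap between the two mean scores in units of the per-fold standard deviation printed by the loop version. It uses only the summary values shown in the outputs; the wrapper runs do not print their own standard deviation, so reusing the loop's value as the yardstick is an assumption, and gap_in_stds is a hypothetical helper.

```python
def gap_in_stds(mean_a, std_a, mean_b):
    """Absolute difference between two mean scores, in units of std_a."""
    return abs(mean_a - mean_b) / std_a


# Values printed in this answer (percent); cross_val_score returns a fraction.
answer_gap = gap_in_stds(48.70, 5.46, 0.4320685863494873 * 100)

# Values printed in the question, using the loop's std as the yardstick.
question_gap = gap_in_stds(24.23, 2.35, 0.5315796375274658 * 100)

print(round(answer_gap, 2))    # 1.01
print(round(question_gap, 2))  # 12.31
```

By this rough measure the difference in these runs is about one standard deviation, consistent with fold-to-fold noise, whereas the difference reported in the question is about twelve, which suggests that something other than noise differs between the question's two scripts.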