如何解决列出测试数据中的错误预测! - Python
我试图在测试集中列出所有错误的预测,但不确定如何去做。我尝试过 Stackoverflow,但可能搜索了错误的“问题”。所以我从一个文件夹中获得了这些文本文件,其中包含电子邮件。问题是我的预测效果不佳,我想检查预测错误的电子邮件。目前我的代码片段看起来像这样:
no_head_train_path_0 = 'folder_name'
no_head_train_path_1 = 'folder_name'
def get_data(path):
text_list = list()
files = os.listdir(path)
for text_file in files:
file_path = os.path.join(path,text_file)
read_file = open(file_path,'r+')
read_text = read_file.read()
read_file.close()
cleaned_text = clean_text(read_text)
text_list.append(cleaned_text)
return text_list,files
no_head_train_0,temp = get_data(no_head_train_path_0)
no_head_train_1,temp1 = get_data(no_head_train_path_1)
no_head_train = no_head_train_0 + no_head_train_1
no_head_labels_train = ([0] * len(no_head_train_0)) + ([1] * len(no_head_train_1))
def vocabularymat(TEXTFILES,VOC,PLAY,METHOD):
from sklearn.feature_extraction.text import CountVectorizer,TfidfVectorizer
if (METHOD == "TDM"):
voc = CountVectorizer()
voc.fit(VOC)
if (PLAY == "TRAIN"):
TrainMat = voc.transform(TEXTFILES)
return TrainMat
if (PLAY =="TEST"):
TestMat = voc.transform(TEXTFILES)
return TestMat
TrainMat = vocabularymat(no_head_train,no_head_train,PLAY= "TRAIN",METHOD="TDM")
X_train = Featurelearning(Traindata,Method="NMF")
y_train = datalabel
X_train,X_test,y_train,y_test = train_test_split(data,datalabel,test_size=0.33,random_state=42
model = LogisticRegression()
model.fit(X_train,y_train)
expected = y_test
predicted = model.predict(X_test)
proba = model.predict_proba(X_test)
accuracy = accuracy_score(expected,predicted)
recall = recall_score(expected,predicted,average="binary")
precision = precision_score(expected,average="binary")
f1 = f1_score(expected,average="binary")
是否可以找到预测错误的电子邮件/文件名,以便我可以手动检查它们? (抱歉代码太长)
解决方法
您可以使用 NumPy 创建一个布尔向量,指示哪些预测是错误的,然后使用该向量来索引您的文件名数组。例如:
import numpy as np
# mock data
files = np.array(['mail1.txt','mail2.txt','mail3.txt','mail4.txt'])
y_test = np.array([0,1,1])
predicted = np.array([0,1])
# create a Boolean index for the wrong classifications
classification_is_wrong = y_test != predicted
# print the file names of the wrongly classified mails
print(files[classification_is_wrong])
输出:
['mail2.txt' 'mail3.txt']
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。