如何解决如何在模型输出中列出每个预测的概率
使用修改后的鸢尾花数据集测试一些用于预测物种的模型。目前仅限于 SVM 和随机森林。在 R-studio 中运行。
简化设置:
library(caret)
#data
data(iris)
#rename
dataset <- iris
#smaller sample
sample_data <- dataset[sample(nrow(dataset),60),]
#create some noise so model is less-than-perfect
noise_df <- data.frame(
Sepal.Length = c(5.7,5.7,7.0,5.0,5.0),Sepal.Width = c(3.8,3.8,2.7,2.8,3.1,3.1),Petal.Length = c(5.2,5.2,5.3,5.4,5.5,5.6,5.8,1.3,1.3),Petal.Width = c(1.8,1.8,1.9,2.0,0.2,0.2),Species = c("setosa","setosa","virginica","virginica")
)
#combine sample with noise
dataset2 <- rbind(sample_data,noise_df)
#split data into train/test
set.seed(7)
validation_index <- createDataPartition(dataset2$Species,p=0.70,list=FALSE)
test_set <- dataset2[-validation_index,]
train_set <- dataset2[validation_index,]
#====================
#build models
#====================
control <- trainControl(method="cv",number=10)
metric <- "Accuracy"
#random forest model
set.seed(3)
fit.rf <- train(Species~.,data=train_set,method="rf",metric=metric,trControl=control)
#svm model
set.seed(3)
fit.svm <- train(Species~.,method="svmRadial",trControl=control)
#====================
#run model on test
#====================
predictions <- predict(fit.svm,test_set)
confusionMatrix(predictions,test_set$Species)
混淆矩阵输出:
Reference
Prediction setosa versicolor virginica
setosa 11 0 3
versicolor 0 3 0
virginica 0 1 5
我想知道是否可以列出每个预测的概率。例如:
setosa versicolor virginica predicted
1 0.9 0.0 0.1 setosa
2 0.1 0.8 0.1 versicolor
3 0.33 0.33 0.33 virginica
我猜随机森林可能只列出 0 对 1,但想知道 SVM 是否可以像上面的例子那样分解概率。如果是这样,我不确定如何塑造我的数据或要使用的函数。它是decision_function 还是predict_proba 函数,但我不清楚如何在r 中正确执行它。
解决方法
对于随机森林,概率是预测每个标签的决策树的比例,您可以使用predict(..,type="prob")
:
data.frame(predict(fit.rf,type="prob",newdata=test_set),predicted=predict(fit.rf,newdata=test_set))
setosa versicolor virginica predicted
147 0.016 0.002 0.982 virginica
15 0.908 0.068 0.024 setosa
103 0.486 0.000 0.514 virginica
118 0.416 0.056 0.528 virginica
129 0.344 0.000 0.656 virginica
39 0.388 0.080 0.532 virginica
对于 kernlab svm,您需要设置 prob.model = TRUE
:
set.seed(3)
fit.svm <- train(Species~.,data=train_set,method="svmRadial",metric=metric,trControl=control,prob.model = TRUE)
data.frame(predict(fit.svm,newdata=test_set,type="prob"),predicted=predict(fit.svm,newdata=test_set))
setosa versicolor virginica predicted
1 0.129916071 0.051873046 0.81821088 virginica
2 0.884025291 0.030853736 0.08512097 setosa
3 0.129054108 0.006256384 0.86468951 virginica
4 0.104952659 0.124066424 0.77098092 virginica
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。