微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

Tidymodels 工作流使用 add_formula() 或 add_variables() 但不使用 add_recipe()

如何解决Tidymodels 工作流使用 add_formula() 或 add_variables() 但不使用 add_recipe()

我在使用 naiveBayes 分类器从有效文本中区分垃圾邮件方法和工作流程中遇到了一些奇怪的行为。我试图使用 tidymodels 和工作流复制机器学习与 R 一书的第 4 章的结果:https://github.com/PacktPublishing/Machine-Learning-with-R-Second-Edition/blob/master/Chapter%2004/MLwR_v2_04.r

虽然我可以使用 add_variables()add_formula() 或不使用工作流重现分析,但使用 add_recipe() 函数的工作流不起作用。

library(RCurl)
library(tidyverse)
library(tidymodels)
library(textrecipes)
library(tm)
library(snowballC) 
library(discrim) 


sms_raw <- getURL("https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/sms_spam.csv")
sms_raw <- read_csv(sms_raw)
sms_raw$type <- factor(sms_raw$type)

set.seed(123)
split <- initial_split(sms_raw,prop = 0.8,strata = "type")
nb_train_sms <- training(split)
nb_test_sms <- testing(split)

# Text preprocessing
reci_sms <- 
  recipe(type ~.,data = nb_train_sms) %>% 
  step_mutate(text = str_to_lower(text)) %>% 
  step_mutate(text = removeNumbers(text)) %>% 
  step_mutate(text = removePunctuation(text)) %>% 
  step_tokenize(text) %>% 
  step_stopwords(text,custom_stopword_source = stopwords()) %>% 
  step_stem(text) %>% 
  step_tokenfilter(text,min_times = 6,max_tokens = 1500) %>% 
  step_tf(text,weight_scheme = "binary") %>% 
  step_mutate_at(contains("tf"),fn =function(x){ifelse(x == TRUE,"Yes","No")}) %>% 
  prep()


df_training <- juice(reci_sms)
df_testing <- bake(reci_sms,new_data = nb_test_sms)

nb_model <- naive_Bayes() %>% 
  set_engine("klaR") 

以下是三个实际产生有效输出代码示例

# --------- works but slow -----
nb_fit <- nb_fit <- workflow() %>%
  add_model(nb_model) %>%
  add_formula(type~.) %>%
  fit(df_training)
nb_tidy_pred <- nb_fit %>% predict(df_testing)


# --------- works  -----
nb_fit <- nb_model %>% fit(type ~.,df_training)
nb_tidy_pred <- nb_fit %>% predict(df_testing)


# --------- works  -----

nb_fit <- workflow() %>%
  add_model(nb_model) %>%
  add_variables(outcomes = type,predictors = everything()) %>%
  fit(df_training)

nb_tidy_pred <- nb_fit %>% predict(df_testing)

虽然下面的代码不起作用

nb_fit <- workflow() %>%
  add_model(nb_model) %>%
  add_recipe(reci_sms) %>%
  fit(data = df_training)

nb_tidy_pred <- nb_fit %>% predict(df_testing)

它也抛出以下错误,但我真的不明白使用 rlang::last_error() 时发生了什么

Not all variables in the recipe are present in the supplied training set: 'text'.
Run `rlang::last_error()` to see where the error occurred.

谁能告诉我我错过了什么?

解决方法

当您在工作流程中使用配方时,您可以将预处理步骤与模型拟合相结合。在拟合该工作流程时,您需要使用配方预期的数据 (nb_train_sms),而不是欧洲防风草模型预期的数据。

此外,它是 not recommended to pass a prepped recipe to a workflow,所以在使用 prep() 将其添加到工作流程之前,看看我们如何不add_recipe()

library(RCurl)
library(tidyverse)
library(tidymodels)
library(textrecipes)
library(tm) 
library(discrim)

sms_raw <- getURL("https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/sms_spam.csv")
sms_raw <- read_csv(sms_raw)
sms_raw$type <- factor(sms_raw$type)

set.seed(123)
split <- initial_split(sms_raw,prop = 0.8,strata = "type")
nb_train_sms <- training(split)
nb_test_sms <- testing(split)

# Text preprocessing
reci_sms <- 
  recipe(type ~.,data = nb_train_sms) %>% 
  step_mutate(text = str_to_lower(text)) %>% 
  step_mutate(text = removeNumbers(text)) %>% 
  step_mutate(text = removePunctuation(text)) %>% 
  step_tokenize(text) %>% 
  step_stopwords(text,custom_stopword_source = stopwords()) %>% 
  step_stem(text) %>% 
  step_tokenfilter(text,min_times = 6,max_tokens = 1500) %>% 
  step_tf(text,weight_scheme = "binary")  %>% 
  step_mutate_at(contains("tf"),fn = function(x){ifelse(x == TRUE,"Yes","No")})

nb_model <- naive_Bayes() %>% 
  set_engine("klaR") 

nb_fit <- workflow() %>%
  add_model(nb_model) %>%
  add_recipe(reci_sms) %>%
  fit(data = nb_train_sms)
#> Warning: max_features was set to '1500',but only 1141 was available and
#> selected.

nb_tidy_pred <- nb_fit %>% predict(nb_train_sms)

reprex package (v1.0.0) 于 2021 年 4 月 19 日创建

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?