如何解决无法在尺度上自定义指标
我尝试按照此处的 yardstick 步骤自定义指标:https://yardstick.tidymodels.org/articles/custom-metrics.html,但不幸的是我未能找到解决问题的方法。
library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.0.3
#> -- Attaching packages -------------------------------------- tidymodels 0.1.2 --
#> v broom 0.7.2 v recipes 0.1.15
#> v dials 0.0.9 v rsample 0.0.8
#> v dplyr 1.0.2 v tibble 3.0.4
#> v ggplot2 3.3.2 v tidyr 1.1.2
#> v infer 0.5.3 v tune 0.1.2
#> v modeldata 0.1.0 v workflows 0.2.1
#> v parsnip 0.1.4 v yardstick 0.0.7
#> v purrr 0.3.4
#> Warning: package 'broom' was built under R version 4.0.3
#> Warning: package 'modeldata' was built under R version 4.0.3
#> Warning: package 'parsnip' was built under R version 4.0.3
#> Warning: package 'recipes' was built under R version 4.0.3
#> Warning: package 'tibble' was built under R version 4.0.3
#> Warning: package 'tune' was built under R version 4.0.3
#> Warning: package 'workflows' was built under R version 4.0.3
#> -- Conflicts ----------------------------------------- tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
library(glmnet)
#> Loading required package: Matrix
#>
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#>
#> expand,pack,unpack
#> Loaded glmnet 4.0-2
library(doParallel)
#> Loading required package: foreach
#> Warning: package 'foreach' was built under R version 4.0.3
#>
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#>
#> accumulate,when
#> Loading required package: iterators
#> Warning: package 'iterators' was built under R version 4.0.3
#> Loading required package: parallel
library(rlang)
#> Warning: package 'rlang' was built under R version 4.0.3
#>
#> Attaching package: 'rlang'
#> The following objects are masked from 'package:purrr':
#>
#> %@%,as_function,flatten,flatten_chr,flatten_dbl,flatten_int,#> flatten_lgl,flatten_raw,invoke,list_along,modify,prepend,#> splice
library(reprex)
#> Warning: package 'reprex' was built under R version 4.0.3
# Custom function
mse_vec <- function(truth,estimate,na_rm = TRUE,...) {
mse_impl <- function(truth,estimate) {
mean((truth - estimate) ^ 2)
}
metric_vec_template(
metric_impl = mse_impl,truth = truth,estimate = estimate,na_rm = na_rm,cls = "numeric",...
)
}
mse <- function(data,...) {
UseMethod("mse")
}
mse.data.frame <- function(data,truth,...) {
metric_summarizer(
metric_nm = "mse",metric_fn = mse_vec,data = data,truth = !! enquo(truth),estimate = !! enquo(estimate),...
)
}
class(mse) <- c("numeric_metric",class(mse))
# Loading data
set.seed(1)
dt <- data.frame(
X = sample(1:1000),y = rnorm(n = 5,mean = 0,sd = 0.75),z = rnorm(n = 10,mean = 5,sd = 0.25)
)
set.seed(123)
dt_splits <- initial_split(dt,prop = 0.7,strata = y)
dt_train <- training(dt_splits)
dt_test <- testing(dt_splits)
dt_rec <- recipe(y ~ .,data = dt_train) %>%
step_zv(all_predictors()) %>%
step_nzv(all_predictors()) %>%
step_bagimpute(all_predictors(),seed_val = sample.int(10^4,1)) %>%
step_normalize(all_numeric(),-all_outcomes()) %>%
step_corr(all_predictors(),threshold = .99)
# Apply processing to test and training data
dt_baked_train <- dt_rec %>% prep() %>% bake(dt_train) # Preprocessed training
dt_baked_test <- dt_rec %>% prep() %>% bake(dt_test) # Preprocessed testing
# Build the model
cv_splits <- vfold_cv(dt_train,v = 5)
outcome = "y"
preds <- names(dt_train)[!names(dt_train) %in% outcome]
en_mod <- linear_reg(mode = "regression",penalty = tune(),mixture = tune()) %>% set_engine("glmnet")
en_wf <- workflow() %>% add_recipe(dt_rec) %>% add_model(en_mod)
en_set <- parameters(penalty(range = c(-10,0),trans = log10_trans()),mixture())
set.seed(100)
en_grid <- grid_latin_hypercube(en_set,size = 100)
en_grid
#> # A tibble: 100 x 2
#> penalty mixture
#> <dbl> <dbl>
#> 1 0.200 0.425
#> 2 0.0000598 0.516
#> 3 0.00000929 0.595
#> 4 0.0000117 0.572
#> 5 0.000130 0.472
#> 6 0.0000000118 0.0378
#> 7 0.00000413 0.306
#> 8 0.00000000117 0.563
#> 9 0.0000749 0.878
#> 10 0.0000169 0.742
#> # ... with 90 more rows
en_ctrl <- control_grid(save_pred = TRUE,verbose = F)
perf_metrics <- metric_set(rmse,rsq,ccc,mse) # custom metric is added here
set.seed(200)
cl <- makePSOCKcluster(8)
registerDoParallel(cl)
en_tune <- en_wf %>% tune_grid(resamples = cv_splits,grid = en_grid,metrics = perf_metrics,control = en_ctrl)
#> Warning: All models Failed. See the `.notes` column.
stopCluster(cl)
en_tune$.notes
#> [[1]]
#> # A tibble: 100 x 1
#> .notes
#> <chr>
#> 1 preprocessor 1/1,model 1/100 (predictions): Error: Result 4 must be a singl~
#> 2 preprocessor 1/1,model 2/100 (predictions): Error: Result 4 must be a singl~
#> 3 preprocessor 1/1,model 3/100 (predictions): Error: Result 4 must be a singl~
#> 4 preprocessor 1/1,model 4/100 (predictions): Error: Result 4 must be a singl~
#> 5 preprocessor 1/1,model 5/100 (predictions): Error: Result 4 must be a singl~
#> 6 preprocessor 1/1,model 6/100 (predictions): Error: Result 4 must be a singl~
#> 7 preprocessor 1/1,model 7/100 (predictions): Error: Result 4 must be a singl~
#> 8 preprocessor 1/1,model 8/100 (predictions): Error: Result 4 must be a singl~
#> 9 preprocessor 1/1,model 9/100 (predictions): Error: Result 4 must be a singl~
#> 10 preprocessor 1/1,model 10/100 (predictions): Error: Result 4 must be a sing~
#> # ... with 90 more rows
#>
#> [[2]]
#> # A tibble: 100 x 1
#> .notes
#> <chr>
#> 1 preprocessor 1/1,model 10/100 (predictions): Error: Result 4 must be a sing~
#> # ... with 90 more rows
#>
#> [[3]]
#> # A tibble: 100 x 1
#> .notes
#> <chr>
#> 1 preprocessor 1/1,model 10/100 (predictions): Error: Result 4 must be a sing~
#> # ... with 90 more rows
#>
#> [[4]]
#> # A tibble: 100 x 1
#> .notes
#> <chr>
#> 1 preprocessor 1/1,model 10/100 (predictions): Error: Result 4 must be a sing~
#> # ... with 90 more rows
#>
#> [[5]]
#> # A tibble: 100 x 1
#> .notes
#> <chr>
#> 1 preprocessor 1/1,model 10/100 (predictions): Error: Result 4 must be a sing~
#> # ... with 90 more rows
由 reprex package (v0.3.0) 于 2020 年 12 月 27 日创建
我曾经在插入符号中制作自定义指标,但鉴于上述错误,我不知道如何使用 yardstick 自定义此指标。如果有人能帮助我,我将不胜感激。谢谢。
# Summary metrics of trained models -----------------------------------------
custom_summary <- function (data,lev = NULL,model = NULL) {
if(length(unique(data$pred)) < 2 || length(unique(data$obs)) < 2) {
resamplCor <- NA } else { resamplCor <- try(cor(data$pred,data$obs,use =
"pairwise.complete.obs"),silent = TRUE)
if (inherits(resamplCor,"try-error")) resamplCor <- NA }
rmse <- hydroGOF::rmse(sim = data$pred,obs = data$obs,na.rm=TRUE)
nrmse <- hydroGOF::nrmse(sim = data$pred,norm = "maxmin",na.rm=TRUE)
pred_error <- (data$obs - data$pred)
apz <- 100*(length(which(pred_error > -1 & pred_error < 0.5))/length(pred_error))
out <- c(nrmse,resamplCor^2,apz,rmse)
names(out) <- c("NRMSE","Rsquared","APZ","RMSE")
out
}
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。