Tidymodels:具有自定义数据拆分的奇怪错误消息

如何解决Tidymodels:具有自定义数据拆分的奇怪错误消息

我正在学习新的tidymodels框架的绳索,因此我可能会误解一些基本知识。

我提供了一个独立的示例,其中包含真实的数据集(从我的工作中得出)。 鉴于我需要使用除最近的观测值以外的所有观测值作为训练集,而仅将最近的观测值用作测试集(因此在这种情况下,测试集仅是观测值)。

但是,我收到一个我无法解读的错误。任何建议表示赞赏。

谢谢!

library(tidyverse) 

library(tidymodels)


df_ini <- structure(list(year = c(1998,2002,2004,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018),capital_n1132g_lag_1 = c(3446.5,4091.1,3655.1,3633.3,3616.2,3450.7,3596.8,3867.2,3372.5,3722.9,3808.5,4005.6,3718.6,3467.9,4214.2,4237.4,4450.2),capital_n117g_lag_1 = c(4920.9,7810.6,8560.3,8679.9,8938.9,9823.8,10467.1,11047.1,11554.3,11849.9,13465.4,13927.5,15510.2,15754.4,16584.7,17647.1,18273.8),capital_n11mg_lag_1 = c(16846,19605,19381.2,19433.5,20051.6,20569.8,22646.1,23674.5,21200.6,20919.6,23157.7,23520.7,24057.7,23832.8,25019.2,27608.2,29790.1),employment_be_lag_1 = c(2834.42,2839.72,2765.53,2731.08,2709.59,2708.39,2774.06,2795.6,2703.36,2668.1,2705.1,2731.67,2727.16,2725.66,2735.69,2750.52,2782.9
    ),employment_c_lag_1 = c(2612.76,2623.69,2552.89,2518.57,2496.98,2499.54,2558.88,2578,2483.97,2447.65,2483.1,2507.41,2500.94,2499.6,2511.75,2523.97,2555.48),employment_j_lag_1 = c(292.93,389.2,389.45,387.53,384.64,389.29,385.77,392.86,383.91,392.18,410.85,419.75,427.59,438.96,440.33,460.84,473.4
    ),employment_k_lag_1 = c(505.33,507.12,510.25,504.63,515.39,523.45,536.6,550.14,546.68,539.96,536.58,534.98,524.13,518.89,511.57,505.32,496.41),employment_mn_lag_1 = c(945.59,1217.96,1289.55,1365.29,1425.81,1537.88,1622.95,1727.76,1704.65,1762.55,1838.16,1896.09,1929.09,1950.02,1968.83,2021.51,2109.71),employment_oq_lag_1 = c(3065.87,3191.75,3280.36,3317.09,3401.65,3476.63,3508.01,3577.75,3683.85,3759.23,3798.35,3850.17,3877.24,3924.06,4002.74,4095.59,4171.72),employment_total_lag_1 = c(14509.58,15127.99,15212.11,15307.28,15491.61,15762.92,16050.92,16356.53,16269.97,16392.87,16647.79,16820.66,16879.06,17039.6,17142.13,17365.32,17650.21),gdp_b1gq_lag_1 = c(187849.7,220525,231862.5,242348.3,254075,267824.4,283978,293761.9,288044.1,295896.6,310128.6,318653.1,323910.2,333146.1,344269.3,357608,369341.3),gdp_p3_lag_1 = c(139695.2,161175.8,169405.6,176316.4,185871.1,194102,200944.4,208857.1,213630.1,218947.2,227250.8,233638.1,238329.3,243860.6,249404.3,257166.5,265900.2),gdp_p61_lag_1 = c(50117.6,71948.6,74346.9,83074.9,90010.4,100076.8,110157.2,113368.1,91435.3,111997.3,123526.3,125801.2,123657.1,126109.3,129183.6,131524,140057.8),gdp_p62_lag_1 = c(19441,26444.4,28995.1,30507,33520.2,36089.5,39104,43056.8,38781.9,39685.8,43784.1,46187.6,49444.7,51746,53585.8,55885.5,59584.7),price_index_lag_1 = c(1.2,2.3,1.3,2,2.1,1.7,2.2,3.2,0.4,3.6,2.6,1.5,0.8,1,2.2),value_be_lag_1 = c(40533.1,48207.1,48673.2,50737.6,52955.2,56872.4,60864.9,61029,56837.8,58433.6,61443,63655.1,64132.3,65542.6,67495.4,71152.6,72698.8),value_c_lag_1 = c(33441.8,40446.6,40467.4,42014.6,44229,47735.5,51552.4,51165.9,47129.7,48759.3,51467.7,53234.6,53431.4,55169,57458.7,60962.8,62196
    ),value_j_lag_1 = c(5483.7,7326.1,7934.1,7756.1,8134.2,8378.8,8532.3,8740,8493.9,8518.9,9217.1,9405.1,9802.1,10361.4,10695.4,11455.3,11720.6),value_k_lag_1 = c(9210.6,9977.3,10146.9,10541.9,11005.3,11912.3,13102.7,13205.2,12123.9,12113.2,12952.8,12254.9,12796.6,12962.4,13482.9,13236.4,13744.1),value_mn_lag_1 = c(10444,14061.4,15706.6,16569.1,18008.7,19576.6,21317,23189.8,22490,23255.2,24895.4,25988.7,26998.2,28027.3,29207.9,30737.7,32259.6
    ),value_oq_lag_1 = c(29902.7,34179.2,36126.8,37329.6,38288.8,40003.1,41511.4,43761.3,45817.8,46996.6,47980.9,49381.5,50261.7,51624.3,53715,55926.4,57637.1),value_total_lag_1 = c(167323.4,197076.7,207247.6,216098.3,225888.1,239076,253604.6,262414.7,256671,263633.5,276404,283548.2,288624.3,297230.1,307037.7,318952.7,329396.1),capital_n1132g_lag_2 = c(3599.2,3996.9,3638.4,4237.4
    ),capital_n117g_lag_2 = c(4636.2,7008.5,8369.6,17647.1),capital_n11mg_lag_2 = c(17181.5,19677.8,18749.6,27608.2),employment_be_lag_2 = c(2870.33,2840.19,2775.22,2750.52),employment_c_lag_2 = c(2626.2,2621.08,2562.53,2523.97
    ),employment_j_lag_2 = c(275.08,374.56,400.75,460.84),employment_k_lag_2 = c(500.9,505.13,502.42,505.32
    ),employment_mn_lag_2 = c(904.38,1143.78,1248.01,2021.51),employment_oq_lag_2 = c(3028.85,3162.77,3241.36,4095.59),employment_total_lag_2 = c(14404.29,15019.87,15113.52,17365.32),gdp_b1gq_lag_2 = c(186928.7,213606.4,226735.3,357608),gdp_p3_lag_2 = c(140335.8,156117.3,164107.8,257166.5),gdp_p61_lag_2 = c(44541.4,67701.6,74691.6,131524),gdp_p62_lag_2 = c(19504.2,24888.9,28063.4,55885.5),value_be_lag_2 = c(40076.7,46109.4,47967.1,71152.6),value_c_lag_2 = c(32955.4,38908.4,40192.9,60962.8),value_j_lag_2 = c(5576.8,6313.9,7737.1,11455.3),value_k_lag_2 = c(9191,10458,10225.2,13236.4),value_mn_lag_2 = c(10092,12942.5,15074,30737.7
    ),value_oq_lag_2 = c(30224.3,33251.5,35065.6,55926.4),value_total_lag_2 = c(167141.8,190624.9,202353.5,318952.7),berd = c(2146.085,3130.884,3556.479,4207.669,4448.676,4845.861,5232.63,5092.902,5520.422,5692.841,6540.457,6778.42,7324.679,7498.488,7824.51,7888.444,8461.72)),row.names = c(NA,-17L),class = c("tbl_df","tbl","data.frame"))





set.seed(1234)  ## to make the results reproducible






## I need a particular custom split of my dataset: the test set consists of only the most recent observation,whereas all the rest is the training set

## see https://github.com/tidymodels/rsample/issues/158


indices <-
  list(analysis   = seq(nrow(df_ini)-1),assessment = nrow(df_ini)
       )

df_split <- make_splits(indices,df_ini)


## df_split <- initial_split(df_ini) ## with the default splitting,## ## the code works

df_train <- training(df_split)
df_test <- testing(df_split)

folded_data <- vfold_cv(df_train,3)



glmnet_recipe <- 
    recipe(formula = berd ~ .,data = df_train) %>%
    update_role(year,new_role = "ID") %>%
  step_zv(all_predictors()) %>% 
  step_normalize(all_predictors(),-all_nominal()) 

glmnet_spec <- 
  linear_reg(penalty = tune(),mixture = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("glmnet") 

glmnet_workflow <- 
  workflow() %>% 
  add_recipe(glmnet_recipe) %>% 
  add_model(glmnet_spec) 




glmnet_grid <- tidyr::crossing(penalty = 10^seq(-6,-1,length.out = 20),mixture = c(0.05,0.2,0.6,1)) 

glmnet_tune <- 
  tune_grid(glmnet_workflow,resamples = folded_data,grid = glmnet_grid,control = control_grid(save_pred = TRUE) ) 

print(collect_metrics(glmnet_tune))
#> # A tibble: 240 x 8
#>       penalty mixture .metric .estimator    mean     n std_err .config 
#>         <dbl>   <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>   
#>  1 0.000001      0.05 rmse    standard   375.        3 48.9    Model001
#>  2 0.000001      0.05 rsq     standard     0.929     3  0.0420 Model001
#>  3 0.00000183    0.05 rmse    standard   375.        3 48.9    Model002
#>  4 0.00000183    0.05 rsq     standard     0.929     3  0.0420 Model002
#>  5 0.00000336    0.05 rmse    standard   375.        3 48.9    Model003
#>  6 0.00000336    0.05 rsq     standard     0.929     3  0.0420 Model003
#>  7 0.00000616    0.05 rmse    standard   375.        3 48.9    Model004
#>  8 0.00000616    0.05 rsq     standard     0.929     3  0.0420 Model004
#>  9 0.0000113     0.05 rmse    standard   375.        3 48.9    Model005
#> 10 0.0000113     0.05 rsq     standard     0.929     3  0.0420 Model005
#> # … with 230 more rows

print(show_best(glmnet_tune,"rmse"))
#> # A tibble: 5 x 8
#>      penalty mixture .metric .estimator  mean     n std_err .config 
#>        <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>   
#> 1 0.000001      0.05 rmse    standard    375.     3    48.9 Model001
#> 2 0.00000183    0.05 rmse    standard    375.     3    48.9 Model002
#> 3 0.00000336    0.05 rmse    standard    375.     3    48.9 Model003
#> 4 0.00000616    0.05 rmse    standard    375.     3    48.9 Model004
#> 5 0.0000113     0.05 rmse    standard    375.     3    48.9 Model005

best_net <- select_best(glmnet_tune,"rmse")


final_net <- finalize_workflow(
  glmnet_workflow,best_net
)


final_res_net <- last_fit(final_net,df_split)
#> x : internal: Error in data.frame(...,check.names = FALSE): arguments imply...
#> Warning: All models Failed in [fit_resamples()]. See the `.notes` column.


print(final_res_net)
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(...,check.names = FALSE): arguments imply differing number of rows: 2,0
#> # resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples  
#> # A tibble: 1 x 5
#>   splits         id               .metrics .notes           .predictions
#>   <list>         <chr>            <list>   <list>           <list>      
#> 1 <split [16/1]> train/test split <NULL>   <tibble [1 × 1]> <NULL>

final_fit <- final_res_net %>%
    collect_predictions()

reprex package(v0.3.0.9001)于2020-10-15创建

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其他元素将获得点击?
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。)
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbcDriver发生异常。为什么?
这是用Java进行XML解析的最佳库。
Java的PriorityQueue的内置迭代器不会以任何特定顺序遍历数据结构。为什么?
如何在Java中聆听按键时移动图像。
Java“Program to an interface”。这是什么意思?
Java在半透明框架/面板/组件上重新绘画。
Java“ Class.forName()”和“ Class.forName()。newInstance()”之间有什么区别?
在此环境中不提供编译器。也许是在JRE而不是JDK上运行?
Java用相同的方法在一个类中实现两个接口。哪种接口方法被覆盖?
Java 什么是Runtime.getRuntime()。totalMemory()和freeMemory()?
java.library.path中的java.lang.UnsatisfiedLinkError否*****。dll
JavaFX“位置是必需的。” 即使在同一包装中
Java 导入两个具有相同名称的类。怎么处理?
Java 是否应该在HttpServletResponse.getOutputStream()/。getWriter()上调用.close()?
Java RegEx元字符(。)和普通点?