如何解决 Tidymodels 中使用自定义数据拆分时出现的奇怪错误信息
我刚开始学习新的 tidymodels 框架,因此可能误解了一些基本概念。
我提供了一个可独立运行的示例,其中使用的是一个真实的数据集(来自我的工作)。我需要把除最近一条以外的所有观测作为训练集,只把最近的一条观测作为测试集(因此这里的测试集只包含一条观测)。使用这种自定义拆分时,调用 last_fit() 会得到一条奇怪的错误信息(见下方代码输出);而改用默认的 initial_split() 时,代码可以正常运行。请问问题出在哪里?
谢谢!
library(tidyverse)
library(tidymodels)
# Input data for the reprex: a tibble built directly with structure(), one row per
# `year`, with lagged economic indicators as predictors and `berd` as the outcome.
# NOTE(review): as it appears here, this literal is inconsistent. row.names declares
# 17 rows, and `year`, `berd` and the *_lag_1 columns have 17 values each, but
# `price_index_lag_1` has only 15 values and every *_lag_2 column has only 4.
# structure() only attaches attributes and does not validate column lengths, so the
# result is a malformed tibble. Values appear to have been lost when the example was
# copied; restore the original column vectors before running this script.
df_ini <- structure(list(year = c(1998,2002,2004,2005,2006,2007,2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018),capital_n1132g_lag_1 = c(3446.5,4091.1,3655.1,3633.3,3616.2,3450.7,3596.8,3867.2,3372.5,3722.9,3808.5,4005.6,3718.6,3467.9,4214.2,4237.4,4450.2),capital_n117g_lag_1 = c(4920.9,7810.6,8560.3,8679.9,8938.9,9823.8,10467.1,11047.1,11554.3,11849.9,13465.4,13927.5,15510.2,15754.4,16584.7,17647.1,18273.8),capital_n11mg_lag_1 = c(16846,19605,19381.2,19433.5,20051.6,20569.8,22646.1,23674.5,21200.6,20919.6,23157.7,23520.7,24057.7,23832.8,25019.2,27608.2,29790.1),employment_be_lag_1 = c(2834.42,2839.72,2765.53,2731.08,2709.59,2708.39,2774.06,2795.6,2703.36,2668.1,2705.1,2731.67,2727.16,2725.66,2735.69,2750.52,2782.9
),employment_c_lag_1 = c(2612.76,2623.69,2552.89,2518.57,2496.98,2499.54,2558.88,2578,2483.97,2447.65,2483.1,2507.41,2500.94,2499.6,2511.75,2523.97,2555.48),employment_j_lag_1 = c(292.93,389.2,389.45,387.53,384.64,389.29,385.77,392.86,383.91,392.18,410.85,419.75,427.59,438.96,440.33,460.84,473.4
),employment_k_lag_1 = c(505.33,507.12,510.25,504.63,515.39,523.45,536.6,550.14,546.68,539.96,536.58,534.98,524.13,518.89,511.57,505.32,496.41),employment_mn_lag_1 = c(945.59,1217.96,1289.55,1365.29,1425.81,1537.88,1622.95,1727.76,1704.65,1762.55,1838.16,1896.09,1929.09,1950.02,1968.83,2021.51,2109.71),employment_oq_lag_1 = c(3065.87,3191.75,3280.36,3317.09,3401.65,3476.63,3508.01,3577.75,3683.85,3759.23,3798.35,3850.17,3877.24,3924.06,4002.74,4095.59,4171.72),employment_total_lag_1 = c(14509.58,15127.99,15212.11,15307.28,15491.61,15762.92,16050.92,16356.53,16269.97,16392.87,16647.79,16820.66,16879.06,17039.6,17142.13,17365.32,17650.21),gdp_b1gq_lag_1 = c(187849.7,220525,231862.5,242348.3,254075,267824.4,283978,293761.9,288044.1,295896.6,310128.6,318653.1,323910.2,333146.1,344269.3,357608,369341.3),gdp_p3_lag_1 = c(139695.2,161175.8,169405.6,176316.4,185871.1,194102,200944.4,208857.1,213630.1,218947.2,227250.8,233638.1,238329.3,243860.6,249404.3,257166.5,265900.2),gdp_p61_lag_1 = c(50117.6,71948.6,74346.9,83074.9,90010.4,100076.8,110157.2,113368.1,91435.3,111997.3,123526.3,125801.2,123657.1,126109.3,129183.6,131524,140057.8),gdp_p62_lag_1 = c(19441,26444.4,28995.1,30507,33520.2,36089.5,39104,43056.8,38781.9,39685.8,43784.1,46187.6,49444.7,51746,53585.8,55885.5,59584.7),price_index_lag_1 = c(1.2,2.3,1.3,2,2.1,1.7,2.2,3.2,0.4,3.6,2.6,1.5,0.8,1,2.2),value_be_lag_1 = c(40533.1,48207.1,48673.2,50737.6,52955.2,56872.4,60864.9,61029,56837.8,58433.6,61443,63655.1,64132.3,65542.6,67495.4,71152.6,72698.8),value_c_lag_1 = c(33441.8,40446.6,40467.4,42014.6,44229,47735.5,51552.4,51165.9,47129.7,48759.3,51467.7,53234.6,53431.4,55169,57458.7,60962.8,62196
),value_j_lag_1 = c(5483.7,7326.1,7934.1,7756.1,8134.2,8378.8,8532.3,8740,8493.9,8518.9,9217.1,9405.1,9802.1,10361.4,10695.4,11455.3,11720.6),value_k_lag_1 = c(9210.6,9977.3,10146.9,10541.9,11005.3,11912.3,13102.7,13205.2,12123.9,12113.2,12952.8,12254.9,12796.6,12962.4,13482.9,13236.4,13744.1),value_mn_lag_1 = c(10444,14061.4,15706.6,16569.1,18008.7,19576.6,21317,23189.8,22490,23255.2,24895.4,25988.7,26998.2,28027.3,29207.9,30737.7,32259.6
),value_oq_lag_1 = c(29902.7,34179.2,36126.8,37329.6,38288.8,40003.1,41511.4,43761.3,45817.8,46996.6,47980.9,49381.5,50261.7,51624.3,53715,55926.4,57637.1),value_total_lag_1 = c(167323.4,197076.7,207247.6,216098.3,225888.1,239076,253604.6,262414.7,256671,263633.5,276404,283548.2,288624.3,297230.1,307037.7,318952.7,329396.1),capital_n1132g_lag_2 = c(3599.2,3996.9,3638.4,4237.4
),capital_n117g_lag_2 = c(4636.2,7008.5,8369.6,17647.1),capital_n11mg_lag_2 = c(17181.5,19677.8,18749.6,27608.2),employment_be_lag_2 = c(2870.33,2840.19,2775.22,2750.52),employment_c_lag_2 = c(2626.2,2621.08,2562.53,2523.97
),employment_j_lag_2 = c(275.08,374.56,400.75,460.84),employment_k_lag_2 = c(500.9,505.13,502.42,505.32
),employment_mn_lag_2 = c(904.38,1143.78,1248.01,2021.51),employment_oq_lag_2 = c(3028.85,3162.77,3241.36,4095.59),employment_total_lag_2 = c(14404.29,15019.87,15113.52,17365.32),gdp_b1gq_lag_2 = c(186928.7,213606.4,226735.3,357608),gdp_p3_lag_2 = c(140335.8,156117.3,164107.8,257166.5),gdp_p61_lag_2 = c(44541.4,67701.6,74691.6,131524),gdp_p62_lag_2 = c(19504.2,24888.9,28063.4,55885.5),value_be_lag_2 = c(40076.7,46109.4,47967.1,71152.6),value_c_lag_2 = c(32955.4,38908.4,40192.9,60962.8),value_j_lag_2 = c(5576.8,6313.9,7737.1,11455.3),value_k_lag_2 = c(9191,10458,10225.2,13236.4),value_mn_lag_2 = c(10092,12942.5,15074,30737.7
),value_oq_lag_2 = c(30224.3,33251.5,35065.6,55926.4),value_total_lag_2 = c(167141.8,190624.9,202353.5,318952.7),berd = c(2146.085,3130.884,3556.479,4207.669,4448.676,4845.861,5232.63,5092.902,5520.422,5692.841,6540.457,6778.42,7324.679,7498.488,7824.51,7888.444,8461.72)),row.names = c(NA,-17L),class = c("tbl_df","tbl","data.frame"))
set.seed(1234) ## to make the results reproducible (e.g. the random fold assignment in vfold_cv())

## Custom split: the test (assessment) set is only the most recent observation,
## whereas all the rest is the training (analysis) set.
## see https://github.com/tidymodels/rsample/issues/158
## The most recent row is located via `year` instead of being assumed to be the
## last row, so the split stays correct even if df_ini is not sorted by year.
indices <- local({
  test_row <- which.max(df_ini$year)
  list(
    analysis   = setdiff(seq_len(nrow(df_ini)), test_row),
    assessment = test_row
  )
})
df_split <- make_splits(indices, df_ini)
## df_split <- initial_split(df_ini)  ## with the default splitting, the code works

df_train <- training(df_split)
df_test <- testing(df_split)

## Three-fold cross-validation on the training rows, used for tuning below.
## NOTE(review): df_ini is a yearly series and vfold_cv() assigns rows to folds at
## random, so later years can be used to predict earlier ones. A time-aware scheme
## (e.g. rsample::rolling_origin()) may be more appropriate; confirm intent.
folded_data <- vfold_cv(df_train, v = 3)
## Preprocessing recipe: model `berd` on all other columns, give `year` the "ID"
## role so it is carried along but not used as a predictor, drop zero-variance
## predictors, then normalize (center and scale) the non-nominal predictors.
glmnet_recipe <- recipe(formula = berd ~ ., data = df_train)
glmnet_recipe <- update_role(glmnet_recipe, year, new_role = "ID")
glmnet_recipe <- step_zv(glmnet_recipe, all_predictors())
glmnet_recipe <- step_normalize(glmnet_recipe, all_predictors(), -all_nominal())
## Elastic-net linear regression fitted with glmnet; both the penalty and the
## L1/L2 mixture are left as tune() placeholders to be chosen by grid search.
glmnet_spec <- set_engine(
  object = set_mode(
    object = linear_reg(penalty = tune(), mixture = tune()),
    mode = "regression"
  ),
  engine = "glmnet"
)
## Bundle the preprocessing recipe and the tunable model specification.
glmnet_workflow <- add_model(add_recipe(workflow(), glmnet_recipe), glmnet_spec)
## Candidate hyper-parameters: every combination of 20 log-spaced penalties
## (1e-6 .. 1e-1) with 4 mixture values.
glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.6, 1)
)

## Evaluate every candidate on the resamples in `folded_data`, keeping the
## out-of-fold predictions, then print the per-candidate summary metrics.
## NOTE(review): every row shown in the output below has the same RMSE (375.,
## std_err 48.9) regardless of penalty; this penalty range may be too small to
## matter for an outcome on the scale of `berd`. A wider range may be worth trying.
## NOTE(review): the recorded output reports 240 rows, while 20 x 4 candidates x 2
## metrics (rmse, rsq) would give 160; it may have been produced with another grid.
glmnet_tune <- glmnet_workflow %>%
  tune_grid(
    resamples = folded_data,
    grid = glmnet_grid,
    control = control_grid(save_pred = TRUE)
  )
glmnet_tune %>% collect_metrics() %>% print()
#> # A tibble: 240 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.000001 0.05 rsq standard 0.929 3 0.0420 Model001
#> 3 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 4 0.00000183 0.05 rsq standard 0.929 3 0.0420 Model002
#> 5 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 6 0.00000336 0.05 rsq standard 0.929 3 0.0420 Model003
#> 7 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 8 0.00000616 0.05 rsq standard 0.929 3 0.0420 Model004
#> 9 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
#> 10 0.0000113 0.05 rsq standard 0.929 3 0.0420 Model005
#> # … with 230 more rows
## Print the top candidates ranked by mean cross-validated RMSE (lower is better).
glmnet_tune %>%
  show_best(metric = "rmse") %>%
  print()
#> # A tibble: 5 x 8
#> penalty mixture .metric .estimator mean n std_err .config
#> <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.000001 0.05 rmse standard 375. 3 48.9 Model001
#> 2 0.00000183 0.05 rmse standard 375. 3 48.9 Model002
#> 3 0.00000336 0.05 rmse standard 375. 3 48.9 Model003
#> 4 0.00000616 0.05 rmse standard 375. 3 48.9 Model004
#> 5 0.0000113 0.05 rmse standard 375. 3 48.9 Model005
## Pick the candidate with the best (lowest) mean cross-validated RMSE and plug
## its penalty/mixture into the workflow.
best_net <- select_best(glmnet_tune, metric = "rmse")
final_net <- finalize_workflow(
  glmnet_workflow, best_net
)

## NOTE(review): the "internal: ... arguments imply differing number of rows: 2,0"
## error recorded below is consistent with tune binding the 2 metric rows (rmse,
## rsq) to an empty label for the split. Presumably a split built by make_splits()
## carries no `id`, unlike the one returned by initial_split() (which works, per the
## comment next to the split). Attaching an id works around it; newer tune/rsample
## releases may not need this, so confirm.
if (is.null(df_split$id)) {
  df_split$id <- tibble::tibble(id = "train/test split")
}

## Fit on the training rows and evaluate on the held-out row. The assessment set
## contains a single observation, for which R-squared (a squared correlation) is
## undefined, so only RMSE is requested instead of the default metric set.
final_res_net <- last_fit(final_net, df_split, metrics = metric_set(rmse))
## Output recorded for the original call, last_fit(final_net, df_split):
#> x : internal: Error in data.frame(...,check.names = FALSE): arguments imply...
#> Warning: All models Failed in [fit_resamples()]. See the `.notes` column.
print(final_res_net)
## Output recorded for the original call:
#> Warning: This tuning result has notes. Example notes on model fitting include:
#> internal: Error in data.frame(...,check.names = FALSE): arguments imply differing number of rows: 2,0
#> # resampling results
#> # Monte Carlo cross-validation (0.94/0.059) with 1 resamples
#> # A tibble: 1 x 5
#> splits id .metrics .notes .predictions
#> <list> <chr> <list> <list> <list>
#> 1 <split [16/1]> train/test split <NULL> <tibble [1 × 1]> <NULL>

## Prediction for the held-out observation.
final_fit <- final_res_net %>%
  collect_predictions()
由reprex package(v0.3.0.9001)于2020-10-15创建
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。