微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

在测试集上训练模型后获取 RMSE 和 R2 时出错

如何解决在测试集上训练模型后获取 RMSE 和 R2 时出错

我有一个训练数据 (train.dat) 和测试数据 (test.dat)。我想在训练数据上训练我的 LASSO 模型后在测试数据上运行它,这似乎没问题。

从那里,我想得到 RMSE 和 R2 来观察模型的预测准确性。但是,我得到了错误pred - obs 中的错误:二元运算符的非数字参数(对于 RMSE)和 完整.cases(pred) 中的错误:并非所有参数都相同R2 的长度。

谁能告诉我我的代码出了什么问题?

library(caret)

tr.Control <- trainControl(method = "repeatedcv",number = 10,repeats = 5,classprobs = FALSE,)
set.seed(10345678)
lasso.fit2 <- train(Lifeexp ~ .,data = train.dat,trControl = tr.Control,method = "glmnet",preProcess = c("center","scale"),tuneGrid =
            expand.grid(alpha = 1,lambda = seq(10^-6,1,length.out = 100)) )

lasso.pred <- predict(lasso.fit2,newdata = test.dat,type = "raw")

RMSE(lasso.fit2,test.dat$Lifeexp)
R2(bag.rf.fit2,test.dat$Lifeexp)

Train.dat:

structure(list(GDP = c(402.1030419,442.2030419,543.3030419,520.8966027,254.2432569,124.4608003,341.5541149,772.3135303,478.6685897,191.8789042,592.4010975,1033.912431,138.4288795,622.4988457,642.7767443,317.3893069,269.8711377,709.5819646,585.07655,780.190201,3122.362815,3893.596078,1166.610276,1674.825261,3690.113268,4241.788782,2441.741991,4043.662051,9040.566251,963.8417858,2234.579866,10330.61561,1944.137621,2136.440243,567.5286729,567.930736,2292.445156,2028.18197,371.6785662,519.5343268,987.409723,1482.403063,1196.586858,1955.588006,6941.235848,1038.90854,3102.713363,3139.966054,3032.427138,7328.615629,869.6965166,2799.648876,617.2304355,1126.683318,4094.362119,7708.100996,10385.96443,11683.94962,718.1878292,3243.231125,3100.280468,11286.24302,8920.762105,201.4671636,785.5022829,1510.324871,1831.001912,8141.913127,12027.36588,6967.24523,7691.345097,3233.295943,367.5566093,1357.563719,1489.876911,977.2736357,1508.942737,2007.736363,5076.342992,7273.563207,948.3318545,2146.996385,95.18825018,390.0933261,2566.59695,52022.1256,57373.68668,19095.467,28149.87001,39435.8399,20600.37525,23041.53473,44141.87814,47518.63604,24190.24962,46232.98962,26891.44645,61350.34791,28364.64508,50152.34014,22303.96133,23635.92922,41531.9342,47603.02763,9600.18513,12042.95373,26917.75898,20324.25356,20087.59199,36000.52012,25423.07201,32018.06325,43024.92384,73191.11632,12663.36453,30693.59308,18440.37852,38577.38166,33994.40657,21290.86038,50950.03434,53024.05921,13663.02162,13641.10272,41945.33167,1731.209509,4492.727604,11861.75616,47236.96023,23509.54339,26123.97387,74605.77451),Health = c(22.23474948,36.44474948,45.58774948,46.38774948,3.333203815,5.359203815,16.69390488,19.46990488,33.22835541,5.300580788,29.97179604,33.59179604,5.971383095,62.66848373,67.22848373,8.23568,14.98141193,32.6487999,10.22661548,16.19961548,92.18703461,98.65987461,143.7665911,159.7515106,308.6578979,402.5568979,99.5689502,111.4155502,292.8907166,198.2263198,221.1403198,705.336568,176.6524443,200.7054443,12.56211728,17.72411728,76.7208786,98.4562786,9.55682529,16.01162529,26.5686245,33.565445,69.66563616,89.45643616,275.2236792,32.77552414,122.5689168,198.7124574,221.7829742,539.567627,43.70681763,108.6149597,33.2254878,42.36598,60.2569,705.1993408,891.1377563,992.5689563,31.84200096,77.2356478,277.45864,891.7641602,932.325129,15.23564,54.30473709,74.231488,200.564125,665.2514038,755.36985,384.9183044,445.20158,262.5267029,11.56898,45.25077438,109.0749969,122.02145,42.568412,62.25963211,172.0576935,200.562134,91.17743683,120.236549,11.23587,18.82835197,99.23568,4952.777344,5236.3654,1101.36589,1674.2854,3309.480957,1654.5687,1845.321045,4449.542969,5000.36545,1998.634277,6054.23658,1900.2356,7025.36987,1000.5689,5036.2356,1233.36545,2334.651855,4597.244629,5698.2547,1500.3698,2000.23564,2573.740234,3002.36547,1520.453613,3214.546387,1569.3254,2873.848145,3644.802734,4587.235478,1122.02145,2211.019043,462.5890808,1061.365601,1256.56897,1987.2145,5186.632813,6547.2356,990.32658,1053.891602,4201.3698,238.0044861,712.2356,1513.565918,2015.18042,2985.23,8021.80957),Govthealth = c(1.25689,2.032658,2.495758057,2.965478,1.985478,2.209019899,2.882325411,3.21458,7.3134408,1.032568,5.433434963,7.235478,1.239725351,8.535984039,10.323589,1.236589,3.562868595,4.673761368,2.32547,4.648055553,23.70949936,33.235687,51025478,71.8605423,205.9026794,295.2356,31.2587,51.99817276,154.70401,56.32588,73.30036926,399.23568,66.3265,99.82849121,2.23568,3.246135235,10.43734169,15.235478,3.569877,5.623521328,5.849419594,8.32665,35.3654457,44.96020508,195.3657,14.55177689,35.235698,61.02356,81.59127045,284.7705994,23.43979454,43.92045593,22.36587,30.42416763,181.3415375,385.9675598,576.0806274,602.3258,25.36730576,66.235687,92.2147,401.4833984,502.3698,2.0214578,10.70767879,15.36987,112.3698,481.0765686,502.36987,226.7909851,300.65478,55.95266342,2.36547,11.85855961,35.50076675,45.235698,25.36954,34.36005783,126.9312592,156.3257,23.53768349,39.235687,4.235687,6.570708275,45.36987,3399.406006,4500.321547,990.36547,1368.160278,2804.857178,1000.365,1375.334717,3458.573975,4120.325,1456.037842,4100.368,1500.36578,6925.325445,990.58795,4125.25658,998.25998,1827.566895,3482.541016,4800.3256,989.325,1254.325,1756.99939,1998.23569,1104.429321,2521.927002,1800.3256,2315.543701,2931.431641,331.0256,548.32,1388.55896,351.3133545,898.4367065,997.02145,956.32547,3488.651855,4400.23556,558.36987,785.0509033,3000.3658,100.36987,162.3498688,162.365,543.0645752,1458.283813,2000.3694,2495.23877),Privhealth = c(14.3698,25.36698,36.01279831,49.36875,1.23569,2.278559208,8.061329842,10.3658,5.059076786,3.25698,20.38587761,30.65877,4.726452827,22.79703331,32.65878,6.32589,10.38636589,19.33849907,8.326589,11.07592678,67.27728271,74.23658,63.235698,83.74517059,88.83229828,96.32568,49.32658,59.41738892,138.1631165,100.23564,147.8399658,300.23568,71.02584,90.6206665,8.365984,11.47062778,61.48280716,74.254785,7.235647,10.26313496,19.40570831,23.65879,33.25478,44.17641068,189.32658,17.06592751,75.325689,89.32658,136.7345276,238.6507721,19.86775017,63.43461227,7.325478,19.23568,25.321547,319.0157471,311.9694214,442.03695,3.889117956,15.3654,115.02365,488.0875244,552.0325698,36.04922485,45.362154,45.23548,182.7733917,202.3654,142.2067719,202.325,197.0276337,9.32658,32.95304871,70.28269196,90.3256,15.021457,27.89465141,44.9021492,60.32568,43.03323364,60.325845,8.325698,11.45799065,1553.358765,2330.2354,201.0214578,305.5347595,503.7982178,301.23565,469.9864197,990.9689331,1200.36987,542.5964966,1823.021457,312.0215478,1100.32145,301.02145,1100.3256,320.365478,507.0849609,1114.720093,2001.23548,401.14567,662.03214,816.2644653,998.32546,416.0243225,692.6192017,402.32564,558.3044434,713.3709106,998.32658,302.0214,793.8995972,111.2757187,162.9289398,212.3657,442.32598,1698.060913,2226.32568,145.2365,268.8859863,902.32568,42.36587,75.64861298,332.65478,970.5014648,556.8964233,700.32658,5526.447266),Population = c(12412308L,20779953L,29185507L,37172386L,47887865L,66224804L,87639964L,109224559L,14539612L,18905478L,27013212L,28087871L,6216341L,32428167L,42723139L,8449913L,10946445L,15049353L,181413402L,211513823L,241834215L,267663435L,3565890L,5122493L,7261539L,9956011L,18029824L,23194257L,28208035L,223158L,279398L,515696L,1432905L,1794571L,95212450L,122283850L,158503197L,195874740L,107647921L,142343578L,179424641L,212215030L,22071433L,26459944L,31989256L,77991755L,106651922L,36800509L,44967708L,51216964L,18777601L,20261737L,3286542L,3089027L,2913021L,36870787L,40788453L,44494502L,591021L,754394L,149003223L,195713635L,209469333L,8975597L,14312212L,16249798L,3119433L,4577378L,4999441L,70878L,71625L,3786695L,873277798L,1234281170L,34545013L,41801533L,56558186L,62952642L,67195028L,69428524L,12697723L,14439018L,67988862L,79910412L,95540395L,22031750L,24982688L,57247586L,58892514L,62766365L,9967379L,10251250L,10895586L,11433256L,30685730L,37057765L,5140939L,5793636L,4986431L,5515525L,79433029L,82211508L,81776930L,82905782L,10196792L,10805808L,11121341L,10731726L,56942108L,59277417L,254826L,281205L,318041L,352721L,4660000L,7623600L,2045123L,2991884L,4137309L,14951510L,16615394L,17231624L,3329800L,3857700L,4841000L,38110782L,38258629L,3047132L,5076732L,2048583L,2073894L,7824909L),Lifeexp = c(50.331,55.841,61.028,64.486,47.099,51.941,61.627,66.24,55.564,54.404,67.611,70.478,61.974,57.099,62.973,45.746,48.069,55.251,62.32,65.772,69.205,71.509,69.872,71.73,73.428,74.405,70.865,72.594,74.493,61.529,70.173,78.627,61.608,52.192,45.9,46.267,50.896,54.332,60.1,62.82,65.264,67.114,66.165,71.111,76.516,68.793,71.095,63.307,56.048,57.669,71.333,75.439,71.836,73.955,76.562,73.576,75.278,76.52,60.884,71.46,66.343,73.619,75.672,53.595,66.56,69.57,75.654,78.769,80.095,74.619,77.672,57.865,66.693,62.764,65.095,70.248,70.623,74.184,76.931,50.64,61.195,70.551,73.025,75.317,81.69512195,82.74878049,75.8804878,77.74146341,80.40243902,76.05195122,77.72195122,80.18292683,81.59512195,79.13658537,81.94878049,74.80536585,81.35121951,74.81317073,81.83414634,75.2277561,77.92682927,79.98780488,80.99268293,76.93902439,77.88780488,80.38780488,81.28780488,79.77804878,82.03658537,78.03634146,79.65365854,81.89756098,82.66097561,76.60731707,81.60243902,73.142,74.358,75.398,76.87804878,80.70243902,81.76097561,75.37804878,78.63658537,81.85853659,70.8902439,73.74878049,75.29512195,81.54146341,79.42195122,81.02926829,82.24634146),Govted = c(1.23568,2.31245,3.47945,5.32658,2.365,3.98311,4.49659,6.32547,3.5398,1.023568,3.63172,5.16365,2.32871,2.38901,2.52076,1.23568,2.97156,3.34389,0.984578,1.36589,2.81228,4.326587,1.2365897,1.9654789,2.3658,3.58851,3.23568,5.97161,4.96645,3.21548,2.32657,6.99139,1.32658,2.012457,3.214587,2.51681,1.83782,2.28687,3.9854587,2.36587,3.22803,3.71993,3.26766,5.32568,5.12579,5.44358,5.72174,2.36578,1.71774,2.3265,3.43017,2.65897,4.58031,5.01971,6.32658,5.51379,6.64043,5.6488,1.235687,1.53379,2.16286,3.24578,6.63445,7.02824,3.325478,3.215487,3.37769,3.23654,3.323568,5.25346,3.50844,1.54406,4.60449,3.326589,4.235478,4.17277,5.55006,6.32365,4.05552,4.06533,5.74164,4.021547,6.40799,6.9874564,5.442,7.32658,8.9854587,5.33591,3.0215478,3.21547,4.91368,6.3265,2.04608,3.23019,4.32658,5.023658,4.29886,4.35239,4.25224,6.44717,6.97848,7.235689,5.43073,5.54157,2.985467,3.124578,3.32652,5.22879,5.48909,4.236587,5.321457,6.323658,7.5698745,3.26587,4.9936,2.325647,3.08044,5.56251,5.965871,4.92605)),row.names = c(1L,2L,3L,4L,5L,6L,7L,8L,11L,13L,15L,16L,18L,23L,24L,25L,26L,27L,29L,30L,31L,32L,33L,34L,35L,36L,37L,38L,39L,41L,42L,44L,45L,46L,49L,50L,51L,52L,53L,54L,55L,56L,57L,58L,60L,62L,64L,65L,66L,67L,70L,71L,73L,74L,75L,78L,79L,80L,82L,84L,85L,87L,88L,89L,91L,92L,93L,95L,96L,99L,100L,103L,105L,107L,111L,112L,113L,114L,115L,116L,119L,120L,121L,122L,124L,127L,128L,129L,130L,131L,133L,134L,135L,136L,138L,140L,141L,144L,145L,148L,149L,150L,151L,152L,153L,154L,155L,156L,158L,159L,161L,162L,163L,164L,165L,167L,170L,171L,172L,173L,175L,176L,177L,178L,180L,181L,182L,185L,187L,191L,192L,195L),class = "data.frame")

测试数据:

structure(list(GDP = c(199.9863423,156.3857186,389.3980332,229.4902871,497.6320261,749.552711,826.6215305,248.0293672,261.8689977,899.6599081,11373.233,7076.662423,5324.61704,5931.453886,5082.354757,715.9137121,2124.05677,6374.028196,463.6186318,4102.48135,5268.848504,4333.482973,564.7796095,2258.183141,3749.75325,302.5771636,3772.870012,2860.43156,4787.780171,1614.640122,749.9085236,4717.143026,443.3141934,2009.978857,483.952592,366.1728076,841.9729898,563.0577411,1317.890706,18211.27459,21679.24784,42943.90227,21448.36196,47450.31847,30743.54768,58041.39844,24285.46682,46459.97325,20825.78421,34483.204,21043.57493,41715.02928,8794.631229,26149.41108,33692.01083,12599.53358,15420.91116,23852.32703,64581.94402,9107.477079,10201.30354,38428.3855,37868.296,82796.54716),Health = c(6.22435541,8.909747124,39.22274712,8.625580788,4.22284155,42.34384155,47.44484155,10.74555809,18.80055809,45.32365,324.6654166,602.659668,504.5536499,594.8854499,239.3392792,22.55662414,91.84031677,624.335527,30.56891763,128.3355597,74.23569,505.4589408,22.23569,69.80043793,311.6526794,19.73552704,251.0935822,211.589745,250.7455292,35.25698,47.90106964,292.54782,18.56432343,70.5685123,10.56888,17.38329887,50.66987,75.201547,78.18682861,1022.5487,1632.427612,4002.325,1452.369,5044.135254,2496.047119,6011.536621,1655.866211,4099.587891,1125.365,4400.325,1496.87854,3000.23568,336.2356,2023.143677,3216.223633,809.1994019,956.21547,820.6981812,1989.235,446.3265,796.6470337,2985.12,3737.802979,9658.23
),Govthealth = c(2.65987,3.350677967,8.32365,1.337858081,0.235689,8.714180946,11.02365,2.356894,4.656533241,5.958777,198.23568,319.1759033,207.0215302,302.654789,123.2336197,29.2992878,300.5689,12.02589,52.658912,22.03256,222.325689,16.3258,50.29269791,129.758316,3.900079966,163.0175018,102.369,156.8104706,4.36987,5.465222836,75.36987,3.839128733,14.32589,3.25478,5.880064487,12.36547,18.02584,30.97570801,990.365478,1116.231445,3201.0245,996.598723,3721.796387,2074.39917,5042.459961,1229.708252,3167.418213,889.32658,3698.23598,944.5585938,1998.02365,200.365778,1396.733398,2517.370117,577.3640747,662.32589,298.1834717,702.369,456.325,568.7339478,889.36547,1045.900513,3987.3654),Privhealth = c(1.36589,1.832908154,7.325698,5.431494236,2.36589,29.85413742,35.3698,4.23568,8.9836483,22.3658,152.36589,263.3545532,225.5363922,301.325478,111.575592,10.23568,60.89479446,336.02145,12.36587,34.3265,223.02145,2.0215478,11.81901455,180.9026947,15.41190529,85.28456879,45.321478,86.49634552,25.36987,39.00668716,220.32145,14.22738075,49.326545,7.02145,11.50323391,20.36587,33.021456,45.45627975,400.23568,516.1798096,NA,400.32547,1322.338745,421.6481018,969.076416,426.0691833,931.8737793,302.1245,886.02154,517.4750366,889.32547,626.4102173,698.8658447,231.8352966,301.0324,522.5147705,1236.021458,117.3658,227.9130707,1965.3256,2691.985107,6600.3256),Population = c(9404500L,11148758L,18143315L,23941110L,5283814L,7527394L,9100837L,17354392L,23650172L,19077690L,31528585L,365734L,2118874L,2448255L,29027674L,61895160L,93966780L,57779622L,17325773L,21670000L,2866376L,32618651L,530804L,685503L,174790340L,12155239L,3962372L,70419L,69650L,4802000L,4077131L,3726549L,1056575549L,1352617328L,20147590L,27275015L,10432421L,11881477L,87967651L,17065100L,19153000L,66460344L,27691138L,34004889L,5339616L,5547683L,5176209L,5363352L,56719240L,60421760L,6289000L,8882800L,2095344L,15925513L,4350700L,38042794L,37974750L,4027887L,5638676L,1998161L,1988925L,6715519L,7184250L,8513227L),Lifeexp = c(46.096,45.09,63.798,62.288,58.824,68.736,70.879,45.853,46.229,58.893,75.997,75.905,56.665,63.373,74.41,66.366,69.823,63.857,69.509,76.812,78.458,71.594,52.878,68.384,70.116,58.432,77.452,66.843,71.116,70.386,69.902,73.6,62.505,69.416,55.5,58.472,58.1,44.649,74.837,76.99463415,79.23414634,81.35609756,77.42195122,81.24634146,76.59268293,79.1,77.46585366,79.87073171,76.97073171,82.94634146,78.95365854,82.80243902,72.15,77.98780488,76.24634146,77.75365854,77.95121951,83.14634146,73.20487805,75.41219512,77.24243902,79.6804878,83.55121951),Govted = c(3.27054,5.24797,4.71484,2.97515,1.36587,4.00675,1.023658,2.46167,4.53477,4.11747,8.34961,10.23547,2.8673,5.326545,6.15899,2.41093,2.11189,2.46866,1.06738,4.02447,3.94893,1.65599,4.68696,1.856231,2.032145,1.56897,2.18109,4.32479,5.326587,0.36589,1.01218,1.45426,5.13722,4.6764,4.89147,7.3265,5.99199,5.36993,8.08434,8.55955,5.71688,6.54071,3.325687,6.12262,1.326587,4.58512,7.00241,5.06843,3.3213,3.32365,4.32657,4.52294,4.7814,5.9658745
)),row.names = c(9L,10L,12L,14L,17L,19L,20L,21L,22L,28L,40L,43L,47L,48L,59L,61L,63L,68L,69L,72L,76L,77L,81L,83L,86L,90L,94L,97L,98L,101L,102L,104L,106L,108L,109L,110L,117L,118L,123L,125L,126L,132L,137L,139L,142L,143L,146L,147L,157L,160L,166L,168L,169L,174L,179L,183L,184L,186L,188L,189L,190L,193L,194L,196L),class = "data.frame")

解决方法

您的测试数据集中有 NA 值,您可以使用以下命令避免错误:lasso.pred <- predict(lasso.fit2,newdata = test.dat,na.action = na.pass,type="raw")

,

您在调用 R2RMSE 函数时犯了一个错误。除了@Bastien Ducreux 提供的建议外,您还可以使用以下代码

lasso.pred <- predict(lasso.fit2,type="raw")

RMSE(lasso.pred,test.dat$Lifeexp)
R2(lasso.pred,test.dat$Lifeexp)

在您的问题中,您调用的是模型本身 (lasso.fit2)。这就是为什么您收到以下错误

pred - obs 中的错误:二元运算符的非数字参数

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。