微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

加速data.table的子集化和数千个回归的实现

如何解决加速data.table的子集化和数千个回归的实现

我有一个 data.table,有 100 行和 3 列。这些行被分为 30 组。三列是我的自变量。

在每次迭代期间,我从每组中随机选取一行并创建一个包含 30 行的子集。

然后我将子集连接到另一个包含我的因变量的 data.table

有数千种可能的组合。我尝试使用 foreach 加速代码,如下所示。到目前为止,我已经尝试了 1000 次迭代,它似乎有所帮助,但由于我将不得不执行更多的数千个组合,我想知道是否有办法提高效率或速度。

library(parallel)
library(foreach)
library(doParallel)

#data.table containing all independent values
ids <- vector()
#my experiment results in multiple rows per group. Creating such repetitive  
#group ids was surprisingly not very straight forward 
for(i in 1:100){ids[i] <- sample(1:30,1)}
ids <- sort(ids)
x1 <- rnorm(100)
x2 <- rnorm(100)
x3 <- rnorm(100)
dd1 <- data.table(ids,x1,x2,x3)

#data.table containing all dependent values
ids <- seq(1:30)
y <- rnorm(30)
dd2 <- data.table(ids,y)

clus <- makeCluster(detectCores() - 1)
registerDoParallel(clus,cores = detectCores() - 1)


out <- foreach(i = 1:1000,.packages=c("dplyr","data.table","caret"),.combine='c') %dopar% {
  dd3 <- dd1[,.SD[sample(.N,min(1,.N))],by = ids]
  dd3 <- right_join(dd2,dd3,by="ids")

  model <- train(y~x1+x2+x3,data = dd3,method = "lm",trControl = trainControl(method="LOOCV"))
  list(model$results$RMSE,model$results$Rsquared,model$results$MAE)
}
stopCluster(clus)

我最近开始习惯 data.table 的语法。我发现依赖某些 dplyr 函数来节省时间更容易。可能有一些不一致之处。我期待任何改进建议。

谢谢

解决方法

如下面的基准测试所示,限速步骤是模型训练。即使将data.table子集时间减少87%,整体运行时间也几乎相同。

library(data.table)
library(caret)
library(microbenchmark)

microbenchmark(
    a = {
        dd3 <- dd1[,.SD[sample(.N,min(1,.N))],by = ids]
        dd3 <- right_join(dd2,dd3,by="ids")
    },b = {
        dd3 <- dd1[sample.int(nrow(dd1))][order(ids)][!duplicated(ids)]
        dd3[,y := dd2$y]
    },times = 10)
# Unit: microseconds
#  expr      min       lq      mean    median       uq      max neval
#     a 5151.775 5178.159 5248.1007 5214.2990 5260.367 5517.200    10
#     b  661.024  671.663  729.1066  699.2115  744.380  988.915    10

microbenchmark(
    a = {
        dd3 <- dd1[,by="ids")
        model <- train(y~x1+x2+x3,data = dd3,method = "lm",trControl = trainControl(method="LOOCV"))
        list(model$results$RMSE,model$results$Rsquared,model$results$MAE)
    },y := dd2$y]
        model <- train(y~x1+x2+x3,times = 10)
# Unit: milliseconds
#  expr      min       lq     mean   median       uq      max neval
#     a 450.1885 451.4723 454.9538 452.6399 459.6504 463.7085    10
#     b 445.2466 446.8068 449.4441 447.1629 450.0173 460.8545    10

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。