r data.table hmisc

Fastest way to check to which bracket a value belongs for large data


Each row of the data below has a list of cut-off values in its strata column. I would like to use that list to determine which bracket the row's value falls into. The solution provided below comes from r2evans.

The problem is that my actual dataset is quite large, so I was wondering whether there is a faster way to achieve the same thing.

# DATA

library(data.table)
library(Hmisc)
dat <- structure(list(values = c(25, 11, 21, 15), strata = list(c(10, 20, 30, 40), c(10, 20, 30), c(10, 20), c(10, 30))), row.names = c(NA, 
-4L), class = c("data.table", "data.frame"))

#    values      strata
# 1:     25 10,20,30,40
# 2:     11    10,20,30
# 3:     21       10,20
# 4:     15       10,30

# CURRENT SOLUTION

setDT(dat)
dat[, cat := mapply(cut, values, strata, oneval=FALSE)]
dat
#    values      strata     cat
#     <num>      <list>  <fctr>
# 1:     25 10,20,30,40 [20,30)
# 2:     11    10,20,30 [10,20)
# 3:     21       10,20 <NA>
# 4:     15       10,30 [10,30]

EDIT:

Result for the actual data:

# 2 minutes
dat[, cat := mapply(cut2, values, strata, oneval=FALSE)]

# 51 seconds
dat[, cat := mapply(cut, values, strata, oneval=FALSE)]

# 21 seconds
solution by arau

# 8 seconds
solution by Uwe

# 44 seconds to load, 0.2 seconds to compute
solution by onyambu

Solution

  • If you need to speed things up, consider using Rcpp. (Note that in this case I just wrote a pure C++ function, which I compiled and loaded into R using Rcpp.)

    Rcpp::cppFunction(
      "std::vector<std::string> interval(std::vector<int> &x,
                                         std::vector<std::vector<int>> &y){
        std::vector<std::string> z;
        z.reserve(x.size());
        std::transform(x.begin(), x.end(), y.begin(), std::back_inserter(z), 
          [&](int a, std::vector<int> b) {
            auto it = std::find_if(b.begin(), b.end(), [=](int w) {return a < w;});
            return (it == b.begin() || it == b.end()) ? \"NA\":
                  '[' + std::to_string(*(it-1)) + ',' + std::to_string(*it) + ')';
        });
        return z;
      }"
    )
    
    dat[, cat:=interval(values, strata)]
    dat
      values      strata     cat
    1:     25 10,20,30,40 [20,30)
    2:     11    10,20,30 [10,20)
    3:     21       10,20      NA
    4:     15       10,30 [10,30)
    

    Speed comparison for the small example dataset:

     microbenchmark::microbenchmark(OP_solution=dat[, cat := mapply(cut, values, strata, oneval=FALSE)], Rcpp=dat[, cat:=interval(values, strata)], Arau_solution = arau(dat))
    Unit: microseconds
              expr      min        lq     mean    median        uq      max neval
       OP_solution  843.000  878.2010  917.592  898.7510  931.1515   1564.9   100
              Rcpp  301.301  317.8515  339.930  326.7015  337.2010   1425.4   100
     Arau_solution 3917.300 4059.9010 7295.939 4214.5005 4381.9510 305882.8   100
    

    And for a larger dataset (the example replicated 1,000 times):

    dat <- do.call(rbind, replicate(1000, dat, simplify = FALSE))
    microbenchmark::microbenchmark(OP_solution=dat[, cat := mapply(cut, values, strata, oneval=FALSE)], Rcpp=dat[, cat:=interval(values, strata)], Arau_solution = arau(dat))
    Unit: milliseconds
              expr        min         lq       mean     median         uq        max neval
       OP_solution 418.765001 438.206901 463.993935 454.652851 472.974501 782.981200   100
              Rcpp   1.547401   1.685352   1.895221   1.759301   1.917451   5.185102   100
     Arau_solution 227.064100 239.994302 260.242264 251.936452 264.225652 605.567101   100
    

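    Note that Rcpp outperforms the rest. Going by the medians, arau's solution is roughly 140 times slower than the Rcpp solution (251.9 ms vs. 1.76 ms), and the OP's solution is roughly 260 times slower.

    where arau is defined as: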

    arau <- function(dat){
      dat[, tmp_id := .I]  # create a temporary identifier for each value row
      
      # NB: can remove the sort call if the strata are already ordered increasing
      dat_long = dat[, .(values, lb = sort(unlist(strata))), by = tmp_id]
      dat_long[, ub := shift(lb, -1), by = tmp_id]
      dat_long = dat_long[!is.na(ub)]
      
      dat_result = merge(
        dat, dat_long[lb <= values & values < ub, -'values'], by = 'tmp_id', all.x = TRUE)
      dat_result[!is.na(lb) & !is.na(ub), cat := paste0('[', lb, ', ', ub, ')')]
      dat_result[, .(values, strata, cat)]
    }
    

    EDIT

    @uwe's claim that the Rcpp code provided is slow is flawed, because they only compare the last part of their code against the Rcpp code. Why not compare their whole code? Here is a microbenchmark of the two complete solutions, reporting the timings rather than graphs.

    Rcpp::cppFunction(
      "std::vector<std::string> interval(std::vector<int> &x,
                                         std::vector<std::vector<int>> &y){
        std::vector<std::string> z;
        z.reserve(x.size());
        std::transform(x.begin(), x.end(), y.begin(), std::back_inserter(z), 
          [&](int a, std::vector<int> b) {
            auto it = std::find_if(b.begin(), b.end(), [=](int w) {return a < w;});
            return (it == b.begin() || it == b.end()) ? \"NA\":
                  '[' + std::to_string(*(it-1)) + ',' + std::to_string(*it) + ')';
        });
        return z;
      }"
    )
    
    
    Rcpp_fun <- function(dat1){
      dat <- copy(dat1)
      dat[, cat:=interval(values, strata)][]
    }
    
    uwe <- function(dat1){
      dat <- copy(dat1)
      setDT(dat)[, strata_id := .I]
      lut <- dat[, .(lo = head(strata[[1]], -1L), 
                     hi = tail(strata[[1]], -1L)), by = strata_id][
                       , cat := sprintf("[%i,%i)", lo, hi)][]
      
      
      dat[lut, on = .(strata_id, values >= lo, values < hi), cat := i.cat]
      dat[, strata_id:=NULL][]
    }
    
    dat1 <- structure(list(values = c(25, 11, 21, 15),
                           strata = list(c(10, 20, 30, 40), c(10, 20, 30), c(10, 20), c(10, 30))),
                      row.names = c(NA, -4L), class = c("data.table", "data.frame"))
    dat1 <- rbindlist(replicate(1e5, dat1, simplify = FALSE))
    microbenchmark::microbenchmark(Rcpp_fun(dat1), uwe(dat1))
    

    Results:

    microbenchmark::microbenchmark(Rcpp_fun(dat1), uwe(dat1))
    Unit: milliseconds
               expr       min        lq      mean   median        uq       max neval
     Rcpp_fun(dat1)  157.5878  173.5254  216.1008  189.931  227.0475   641.286   100
          uwe(dat1) 5905.8138 6332.8186 7096.8007 6689.716 7270.8831 15575.635   100
    

    The results clearly show that the Rcpp code is at least 30 times faster than @uwe's code (medians of 189.9 ms vs. 6689.7 ms, roughly 35 times). Why would someone claim their code is faster when the graphs provided show otherwise? Note that the benchmark above takes roughly 5 minutes to run.

    Unless shown otherwise, the Rcpp code provided is the fastest of the three solutions.
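
For readers who want to see the lookup in isolation, here is a minimal standalone C++ sketch of the same bracket search, outside of Rcpp. It is a variation, not the code benchmarked above: the helper names bracket_label and bracket_labels are hypothetical, it assumes each row's cut-offs are sorted in increasing order, and it swaps the linear std::find_if scan for a binary search with std::upper_bound, which may help when a row has many cut-offs. It has not been timed against any of the solutions here.

```cpp
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of the original answer): returns the label of
// the half-open bracket [lo,hi) containing `value`, or "NA" when `value` lies
// below the first cut-off or at/above the last one.
// Assumption: `breaks` is sorted in increasing order.
std::string bracket_label(int value, const std::vector<int>& breaks) {
    // Binary search for the first cut-off strictly greater than `value`.
    auto it = std::upper_bound(breaks.begin(), breaks.end(), value);
    if (it == breaks.begin() || it == breaks.end()) {
        return "NA";
    }
    return "[" + std::to_string(*(it - 1)) + "," + std::to_string(*it) + ")";
}

// Applies bracket_label row by row, mirroring the values/strata columns.
// Rows beyond the shorter of the two inputs are ignored.
std::vector<std::string> bracket_labels(
        const std::vector<int>& values,
        const std::vector<std::vector<int>>& strata) {
    std::vector<std::string> out;
    const std::size_t n = std::min(values.size(), strata.size());
    out.reserve(n);
    for (std::size_t i = 0; i < n; ++i) {
        out.push_back(bracket_label(values[i], strata[i]));
    }
    return out;
}
```

With the four example rows, bracket_labels({25, 11, 21, 15}, {{10, 20, 30, 40}, {10, 20, 30}, {10, 20}, {10, 30}}) gives [20,30), [10,20), NA and [10,30), the same labels as the Rcpp output shown earlier. Like the Rcpp version, it works on integers, so non-integer values would need a double-based variant.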