
Filter R data.table based on groupings created using cumulative sum of a column


I need an efficient data.table solution that keeps only the first and last row of each group, where a new group starts every time the cumulative sum of a column exceeds 300. My real dataset has millions of rows, so I am NOT looking for a looped solution.

#Example data:
  library(data.table)
  dt <- data.table(idcolref = 1:1000, y = rep(10, 1000))

An example loop that does what I'd like is below, but it is far too slow to be useful for a large data.table.

###example of a loop that produces the result I want but is too slow
  library(foreach)
  dt[,grp:=1,]
  dt[,cumsum:=0,]
  grp <- 1
  foreach(a=2:nrow(dt))%do%{
    dt[a,"cumsum"]<-dt[a,"y"]+dt[a-1,"cumsum"]
    if(dt[a,"cumsum"]>300){
      dt[a,"grp"] <- grp
      grp <- grp+1
      dt[a,"cumsum"]<-0
    }else{
      dt[a,"grp"]<-dt[a-1,"grp"]
    }
  }
  dt.desired <- foreach(a=2:nrow(dt),.combine=rbind)%do%{
    if(dt[a,"grp"]!=dt[a-1,"grp"]){
      dt[c(a-1,a),]
    }
  }
  dt.desired <- rbind(dt[1,],dt.desired)
  dt.desired <- rbind(dt.desired,dt[nrow(dt),])

How can I get the same result using fast vectorized data.table functions? Thanks!


Solution

  I think I've interpreted your requirement correctly:

    1. You want to calculate the cumulative sum of a vector (column).
    2. If the cumulative sum exceeds 300, you want to reset it to 0.
    3. Each time you reset to 0, that row and the rows after it belong to a new group.
    4. You want to select the first and last rows of each group.
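To check this interpretation on a small input, here is a minimal plain-R sketch of steps 1 to 3. It uses base R's `Reduce(accumulate = TRUE)` to carry the running sum forward; the helper name `group_by_capped_cumsum` and its `threshold` argument are hypothetical. It is meant to follow the same reset rule as the Rcpp function below (the row that pushes the sum over the threshold starts the next group, and its value is not carried into that group's sum), but it still runs in interpreted R, so it is only suitable for checking small inputs, not for millions of rows.

```r
# Hypothetical helper: plain-R version of steps 1-3, for checking small inputs only.
group_by_capped_cumsum <- function(y, threshold = 300) {
  # Each state is c(running_sum, group_id); start at sum 0, group 0
  states <- Reduce(
    function(state, value) {
      s <- state[1] + value
      # Exceeding the threshold resets the sum and starts the next group
      if (s > threshold) c(0, state[2] + 1) else c(s, state[2])
    },
    y,
    init = c(0, 0),
    accumulate = TRUE,
    simplify = FALSE
  )
  # Drop the initial state and keep only the group id of each row
  vapply(states[-1], function(st) st[2], numeric(1))
}

g <- group_by_capped_cumsum(rep(10, 35))
```

With `y = rep(10, 35)`, the running sum reaches exactly 300 at row 30 (not yet over the threshold), so rows 1 to 30 get group 0; row 31 (sum 310) starts group 1, and rows 31 to 35 stay in it.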

    If this is the case, you can write your own fast function in Rcpp. The loop runs in compiled C++, so from R it behaves like a 'vectorised' function:

    library(data.table)
    
    dt <- data.table(x=rep(5,1e7),y=rep(10,1e7))
    ## adding a row index to keep track of which rows are returned
    dt[, id := .I]
    
    library(Rcpp)
    
    cppFunction('Rcpp::NumericVector findGroupRows(Rcpp::NumericVector x) {
    
      double cumsum = 0;  // double, not int, so fractional values in x are not truncated
      int grpCounter = 0;
      size_t n = x.length();
      Rcpp::NumericVector groupedCumSum(n);
    
      for ( size_t i = 0; i < n; i++) {
        cumsum += x[i];
        if (cumsum > 300) {
          cumsum = 0;
          grpCounter++;
        }
        groupedCumSum[i] = grpCounter;
      }
      return groupedCumSum;
    }')
    
    dt[, grp := findGroupRows(y)]
    
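    ## .I[c(1, .N)] gives each group's first and last row number;
    ## unique() avoids returning a one-row group twice
    dt[unique(dt[, .I[c(1L, .N)], by = grp]$V1)]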
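To show what the final selection step does, here is a small sketch on a hand-made table; `toy` and its `grp` values are made up for illustration rather than produced by `findGroupRows()`. Inside `by = grp`, `.I` holds the row numbers of the full table, so `.I[c(1L, .N)]` returns the first and last row number of each group. When a group has only one row, that row number appears twice, which is why wrapping the indices in `unique()` is useful before subsetting.

```r
library(data.table)

# Toy table with the group column already assigned (hypothetical values)
toy <- data.table(id = 1:6, grp = c(0, 0, 0, 1, 2, 2))

# First and last row number of each group; group 1 has a single row,
# so its row number (4) is returned twice
idx <- toy[, .I[c(1L, .N)], by = grp]$V1

# unique() drops the repeated index, so one-row groups appear only once
first_last <- toy[unique(idx)]
```

Here `idx` is `c(1, 3, 4, 4, 5, 6)` and `first_last` contains the rows with `id` 1, 3, 4, 5 and 6.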