
How to filter a data.table based on an uncertain number of conditions?


Given the following data.table in R:

library(data.table)

set.seed(123666)
dt <- data.table(sample1 = sample(10), 
                 sample2 = sample(10),
                 sample3 = sample(10),
                 sample4 = sample(10), 
                 sample5 = sample(10),
                 sample6 = sample(10))
dt
    sample1 sample2 sample3 sample4 sample5 sample6
 1:       2       6       3       9       1       2
 2:      10       9      10       3       7       5
 3:       6      10       8       5       5       1
 4:       8       2       9       8       6       6
 5:       5       4       5      10      10       8
 6:       7       1       7       4       4      10
 7:       4       3       1       6       3       7
 8:       1       5       6       1       2       3
 9:       3       7       2       2       8       9
10:       9       8       4       7       9       4

Assume the first 3 samples belong to group_a and the last 3 to group_b. We want to keep only the rows in which, within each group, at least 2 of the 3 samples are greater than 2. For this fixed case, the following code works:

group_a <- paste0('sample', seq(1,3))
group_b <- paste0('sample', seq(4,6))

dt[rowSums(dt[, ..group_a] > 2) >= 2 & rowSums(dt[, ..group_b] > 2) >= 2]
   sample1 sample2 sample3 sample4 sample5 sample6
1:      10       9      10       3       7       5
2:       6      10       8       5       5       1
3:       8       2       9       8       6       6
4:       5       4       5      10      10       8
5:       7       1       7       4       4      10
6:       4       3       1       6       3       7
7:       3       7       2       2       8       9
8:       9       8       4       7       9       4

Now consider a data.table in which each column is still a sample, but the number of samples (and groups) is not known in advance. An additional named vector, group, records which group each sample belongs to:

group <- paste0('sample', seq(1,6))
group_id <- c(rep('group_a', 3), rep('group_b', 3))
names(group) <- group_id 
group
  group_a   group_a   group_a   group_b   group_b   group_b 
"sample1" "sample2" "sample3" "sample4" "sample5" "sample6"

How can this be accomplished with data.table syntax, as concisely as possible?


Solution

  • You can split the sample names by group, check the condition on each group's columns, and then combine the per-group results with Reduce to subset the rows:

    library(data.table)
    
    dt[Reduce(`&`, lapply(split(group, names(group)), \(x) rowSums(dt[, .SD, .SDcols = x] > 1) >= 2 )), ]
    
       sample1 sample2 sample3 sample4 sample5 sample6
     1:       2       6       3       9       1       2
     2:      10       9      10       3       7       5
     3:       6      10       8       5       5       1
     4:       8       2       9       8       6       6
     5:       5       4       5      10      10       8
     6:       7       1       7       4       4      10
     7:       4       3       1       6       3       7
     8:       1       5       6       1       2       3
     9:       3       7       2       2       8       9
    10:       9       8       4       7       9       4
    

    With a threshold of 1, as above, every row passes. Using the threshold from the question (at least two values greater than 2 in each group) shows that the filter works:

    dt[Reduce(`&`, lapply(split(group, names(group)), \(x) rowSums(dt[, .SD, .SDcols = x] > 2) >= 2 )), ]
    
       sample1 sample2 sample3 sample4 sample5 sample6
    1:      10       9      10       3       7       5
    2:       6      10       8       5       5       1
    3:       8       2       9       8       6       6
    4:       5       4       5      10      10       8
    5:       7       1       7       4       4      10
    6:       4       3       1       6       3       7
    7:       3       7       2       2       8       9
    8:       9       8       4       7       9       4
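
To see why the one-liner works, it can help to look at the intermediate objects one step at a time. The sketch below rebuilds dt from the values printed in the question, so it does not depend on the random seed; the names cols_by_group, keep_by_group and keep are only illustrative and do not appear in the original answer.

```r
library(data.table)

# Rebuild the table from the values printed in the question
# (avoids depending on the RNG or the R version).
dt <- data.table(
  sample1 = c(2, 10, 6, 8, 5, 7, 4, 1, 3, 9),
  sample2 = c(6, 9, 10, 2, 4, 1, 3, 5, 7, 8),
  sample3 = c(3, 10, 8, 9, 5, 7, 1, 6, 2, 4),
  sample4 = c(9, 3, 5, 8, 10, 4, 6, 1, 2, 7),
  sample5 = c(1, 7, 5, 6, 10, 4, 3, 2, 8, 9),
  sample6 = c(2, 5, 1, 6, 8, 10, 7, 3, 9, 4)
)

group <- paste0("sample", 1:6)
names(group) <- rep(c("group_a", "group_b"), each = 3)

# Step 1: one character vector of column names per group.
cols_by_group <- split(group, names(group))

# Step 2: one logical vector per group (TRUE = the row passes in that group).
keep_by_group <- lapply(cols_by_group, function(x) {
  rowSums(dt[, .SD, .SDcols = x] > 2) >= 2
})

# Step 3: keep a row only if it passes in every group.
keep <- Reduce(`&`, keep_by_group)

which(!keep)  # rows 1 and 8 fail (each has only one value > 2 in group_b)
dt[keep]      # the 8 remaining rows
```

Because split() and lapply() work on however many groups the named vector contains, the same three steps apply unchanged when more samples or groups are added.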
    

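    @r2evans suggested an alternative that might perform better when there are many groups. It flags the rows that fail in any group and keeps the rest. Note that counting values `<= 2` this way is equivalent to the original condition only because each group has exactly three samples.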

    dt[rowSums(sapply(split(group, names(group)), \(x) rowSums(dt[, .SD, .SDcols = x] <= 2) >= 2 )) == 0, ]
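
If the threshold and the minimum count need to vary, the Reduce approach can be wrapped in a small function. This is a minimal sketch, not part of the original answers: filter_by_group and its arguments threshold and min_count are hypothetical names, and dt is rebuilt from the values printed in the question so the example is self-contained.

```r
library(data.table)

# Table rebuilt from the values printed in the question.
dt <- data.table(
  sample1 = c(2, 10, 6, 8, 5, 7, 4, 1, 3, 9),
  sample2 = c(6, 9, 10, 2, 4, 1, 3, 5, 7, 8),
  sample3 = c(3, 10, 8, 9, 5, 7, 1, 6, 2, 4),
  sample4 = c(9, 3, 5, 8, 10, 4, 6, 1, 2, 7),
  sample5 = c(1, 7, 5, 6, 10, 4, 3, 2, 8, 9),
  sample6 = c(2, 5, 1, 6, 8, 10, 7, 3, 9, 4)
)

group <- paste0("sample", 1:6)
names(group) <- rep(c("group_a", "group_b"), each = 3)

# Hypothetical helper: keep rows where, in every group, at least
# `min_count` columns are greater than `threshold`.
filter_by_group <- function(dt, group, threshold = 2, min_count = 2) {
  cols_by_group <- split(unname(group), names(group))
  keep <- Reduce(`&`, lapply(cols_by_group, function(cols) {
    rowSums(dt[, .SD, .SDcols = cols] > threshold) >= min_count
  }))
  dt[keep]
}

nrow(filter_by_group(dt, group))                 # 8 rows with the defaults
nrow(filter_by_group(dt, group, threshold = 1))  # all 10 rows pass

# The rowSums/sapply variant gives the same rows for this data
# (three samples per group).
alt <- dt[rowSums(sapply(split(group, names(group)), function(x) {
  rowSums(dt[, .SD, .SDcols = x] <= 2) >= 2
})) == 0, ]
all.equal(filter_by_group(dt, group), alt)
```

The helper counts values above the threshold directly, so it does not rely on every group having the same number of samples.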