r performance matrix indexing

Identifying data frame rows in R with specific pairs of values in two columns


I would like to identify all rows in a data frame (or matrix) whose values in columns 1 and 2 match a specific pair. For example, if I have a matrix

testmat = rbind(c(1,1), c(1,2), c(1,4), c(2,1), c(2,4), c(3,4), c(3,10))

I would like to identify the rows that contain any of the following pairs, i.e. all rows whose first and second columns hold either 1,2 or 2,4:

of_interest = rbind(c(1,2), c(2,4))

The following does not work

which(testmat[, 1] %in% of_interest[, 1] & testmat[, 2] %in% of_interest[, 2])

because, as expected, it tests each column independently: it returns every row with 1 or 2 in the first column and 2 or 4 in the second (i.e. rows 2, 3 and 5 rather than just rows 2 and 5 as desired). The row [1,4] is therefore included even though it is not one of the pairs I'm querying for. There must be some simple way to use which(... %in% ...) to match specific pairs like this, but I haven't been able to find a working example.
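
To make the failure concrete, the two membership tests can be inspected separately. This is only a sketch that rebuilds testmat and of_interest from above; the names col1_hit and col2_hit are introduced here for illustration.

```r
testmat <- rbind(c(1,1), c(1,2), c(1,4), c(2,1), c(2,4), c(3,4), c(3,10))
of_interest <- rbind(c(1,2), c(2,4))

# Each test looks at one column only, so the pairing between columns is lost
col1_hit <- testmat[, 1] %in% of_interest[, 1]  # first column is 1 or 2
col2_hit <- testmat[, 2] %in% of_interest[, 2]  # second column is 2 or 4

which(col1_hit)
# [1] 1 2 3 4 5
which(col2_hit)
# [1] 2 3 5 6
which(col1_hit & col2_hit)
# [1] 2 3 5
```

Row 3 passes both tests because its 1 comes from the pair (1,2) and its 4 comes from the pair (2,4).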

Note that I need the positions/row numbers of the rows which match the desired condition.


Solution

  • Standard approach

    Since you're using which(), I assume you want the positions of the matching rows rather than just whether there is a match. You can cbind() the row numbers to testmat and then merge() the result with of_interest.

    merge(
        cbind(testmat, seq_len(nrow(testmat))),
        of_interest
    ) |> setNames(c("x", "y", "row_num"))
    
    #   x y row_num
    # 1 1 2       2
    # 2 2 4       5
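
If only the row numbers are needed on small data, a which(... %in% ...) form is also possible by collapsing each row into a single key. This is a sketch of a different technique from the merge() above, not a replacement for it; the helper name key is hypothetical, and it assumes the values print unambiguously (e.g. whole numbers). Building character keys is memory-hungry, so it is unlikely to suit very large matrices.

```r
testmat <- rbind(c(1,1), c(1,2), c(1,4), c(2,1), c(2,4), c(3,4), c(3,10))
of_interest <- rbind(c(1,2), c(2,4))

# One string key per row; the separator keeps (1, 12) distinct from (11, 2)
key <- function(m) paste(m[, 1], m[, 2], sep = "_")

# %in% now compares whole pairs rather than individual columns
matched_rows <- which(key(testmat) %in% key(of_interest))
matched_rows
# [1] 2 5
```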
    

    Rcpp approach with very large matrix

    You mention in your comment that you have 1e8 rows. This makes me think two things:

    1. Don't merge() as this will coerce matrices to data frames, i.e. copy each column into a memory-contiguous vector, which will be very expensive.
    2. If of_interest is also large, you want to break out of the loop as soon as a match is found rather than continuing to iterate. See this question for performance advantages.

    Given this I would avoid using which() or other approaches which do not exit early. Here's some Rcpp code that should be much faster than merge() with large datasets:

    Rcpp::cppFunction("
    IntegerVector get_row_position(NumericMatrix testmat, NumericMatrix of_interest) {
        const R_xlen_t nrow_testmat = testmat.nrow();
        const R_xlen_t nrow_of_interest = of_interest.nrow();
        IntegerVector result;
    
        // loop through the rows of testmat
        for (R_xlen_t i = 0; i < nrow_testmat; ++i) {
            for (R_xlen_t j = 0; j < nrow_of_interest; ++j) {
                if (testmat(i, 0) == of_interest(j, 0) && testmat(i, 1) == of_interest(j, 1)) {
                    result.push_back(i + 1); // because of 1-indexing
                    break; // leave inner loop early
                }
            }
        }
        return result;
    }
    ")
    
    get_row_position(testmat, of_interest)
    # [1] 2 5
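
For readers without a C++ toolchain, the same pairwise comparison can be sketched in base R, which is handy for cross-checking results on small inputs. The function name rows_matching_pairs is hypothetical. Unlike the Rcpp version it has no per-row early exit; it makes one vectorised pass over the matrix per pair, so its cost grows with the number of pairs.

```r
# Base R reference version: returns the 1-based positions of rows in `mat`
# whose first two columns equal any row of `pairs`
rows_matching_pairs <- function(mat, pairs) {
  hit <- logical(nrow(mat))  # all FALSE to start with
  for (j in seq_len(nrow(pairs))) {
    # a row matches pair j only if both columns agree with it
    hit <- hit | (mat[, 1] == pairs[j, 1] & mat[, 2] == pairs[j, 2])
  }
  which(hit)
}

testmat <- rbind(c(1,1), c(1,2), c(1,4), c(2,1), c(2,4), c(3,4), c(3,10))
of_interest <- rbind(c(1,2), c(2,4))

rows_matching_pairs(testmat, of_interest)
# [1] 2 5
```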
    

    Note: This previously accessed rows as sub-matrices, e.g. NumericMatrix::Row test_row = testmat(i, _);. That is more idiomatic Rcpp than a double for-loop with matrix indexing, but benchmarking showed it to be much slower, so the function now compares the elements directly. See the edit history for the previous version.

    A quick benchmark

    I updated the above function after the nice answer from jblood94, which showed that the previous approach was slower than some base R approaches. I use the 100-million-row benchmark from that answer and compare against its rowmatch3() function, which was the fastest there (and around 8 times faster than my previous answer). The updated approach is around 5 times faster than rowmatch3().

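    # rowmatch3() comes from jblood94's answer and must be defined before running this.
    # 1e8 x 2 numeric matrix of values sampled with replacement from 1:4000
    testmat <- `dim<-`(as.numeric(sample(4e3, 2e8, replace = TRUE)), c(1e8, 2))
    # up to 5 unique pairs to search for
    matchmat <- unique(`dim<-`(sample(4e3, 10, replace = TRUE), c(5, 2)))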
    microbenchmark::microbenchmark(
        get_row_position = get_row_position(testmat, matchmat),
        rowmatch3 = rowmatch3(testmat, matchmat),
        check = "identical",
        unit = "relative",
        times = 10L
    )
    
    # Unit: relative
    #              expr      min       lq     mean   median       uq      max neval cld
    #  get_row_position 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000    10  a 
    #         rowmatch3 5.262158 5.309956 5.405731 5.426385 5.469428 5.321671    10   b