
R values that go into matrix multiplication


What is the fastest way to save the unique nonzero values that go into a matrix multiplication?
For example, if I have a data.table object

library(data.table)
A = data.table(j3=c(3,0,3),j5=c(0,5,5),j7=c(0,7,0),j8=c(8,0,8))

I would like to see which unique values go into A %*% t(A) (i.e., as.matrix(A) %*% as.matrix(t(A))). Right now, I do it with nested for loops:

B=t(A)
L = list()
models = c('A1','A2','A3')

for(i in 1:nrow(A)){
    for(j in 1:ncol(B)){
        u = union(unlist(A[i,]),B[,j])
        u = u[u!=0] # remove 0
        L[[paste(models[i],models[j])]]= u
    }
}

However, is there a faster and more RAM-efficient way? The output doesn't have to be a list as in my code; a data.table (or data.frame) would be fine as well. The order of values is also unimportant: 3 5 8 is as good as 5 3 8, 8 5 3, etc.

Any help is appreciated.

EDIT: So as.matrix(A) %*% as.matrix(t(A)) is:

     [,1] [,2] [,3]  
[1,]   73    0   73  
[2,]    0   74   25
[3,]   73   25   98

The first element is calculated as 3*3+0*0+0*0+8*8 = 73, the second as 3*0+0*5+0*7+8*0 = 0, and so on. I need the unique numbers that go into each of these calculations, excluding 0.
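To make this definition concrete, the following base-R sketch (not part of the original post) recomputes the first two entries and the nonzero values behind entry [1, 2]. It uses a plain numeric matrix M holding the same data instead of the data.table, and `vals_for` is a hypothetical helper name.

```r
# Same data as the question, as a plain numeric matrix (no data.table needed)
M <- cbind(j3 = c(3, 0, 3), j5 = c(0, 5, 5), j7 = c(0, 7, 0), j8 = c(8, 0, 8))

# Entry [i, j] of M %*% t(M) is the dot product of rows i and j of M
sum(M[1, ] * M[1, ])  # 3*3 + 0*0 + 0*0 + 8*8 = 73
sum(M[1, ] * M[2, ])  # 3*0 + 0*5 + 0*7 + 8*0 = 0

# Hypothetical helper: the distinct nonzero values that enter entry [i, j]
vals_for <- function(M, i, j) {
  u <- union(M[i, ], M[j, ])  # union() removes duplicates
  u[u != 0]                   # then drop the zeros
}
unname(vals_for(M, 1, 2))     # 3 8 5 7, matching `A1 A2` below
```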

Therefore, the outputs (saved in the list L) are:

> L  
$`A1 A1`  
[1] 3 8

$`A1 A2`  
[1] 3 8 5 7

$`A1 A3`   
[1] 3 8 5

$`A2 A1`  
[1] 5 7 3 8

$`A2 A2`  
[1] 5 7

$`A2 A3`   
[1] 5 7 3 8

$`A3 A1`  
[1] 3 5 8

$`A3 A2`   
[1] 3 5 8 7

$`A3 A3`   
[1] 3 5 8

Once again, the output doesn't have to be a list. I would prefer a data.table if that is doable. Is it possible to rewrite my approach as an Rcpp function?


Solution

  • Potential optimizations

    Following up on @user2554330's answer, note that if A is an m-by-n matrix, then AAT = A %*% t(A) (equivalently tcrossprod(A)) is an m-by-m symmetric matrix. AAT[i, j] and AAT[j, i] are computed using the same entries of A, so you only need to inspect m*(m+1)/2 pairs of rows of A, not m*m.
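A quick base-R check of both claims on a small random integer matrix (the dimensions and seed are arbitrary choices for illustration). The `J` construction is the same one used in `f` and `g` below.

```r
set.seed(42)
m <- 5L
n <- 8L
A <- matrix(sample(0:3, m * n, replace = TRUE), nrow = m)

AAT <- tcrossprod(A)                  # same values as A %*% t(A)
all(AAT == A %*% t(A))                # TRUE
isSymmetric(AAT)                      # TRUE: AAT[i, j] == AAT[j, i]

# Enumerate only the pairs (i, j) with i <= j, column by column,
# exactly as f() and g() do
a <- seq_len(m)
J <- cbind(sequence(a), rep.int(a, a))
nrow(J)                               # 15 == m * (m + 1) / 2, instead of 25
```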

    You can do even better by finding and caching the unique elements of each row before pairing them. Preprocessing in this way avoids redundant computation and should noticeably improve performance when m << n.
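Caching is safe because unique() keeps first occurrences in order, so unique(c(unique(x), unique(y))) returns exactly unique(c(x, y)). That is also why the two functions below return identical results. A minimal check on two made-up rows:

```r
set.seed(1)
x <- sample(0:5, 20, replace = TRUE)   # stands in for row i of A
y <- sample(0:5, 20, replace = TRUE)   # stands in for row j of A

# Without preprocessing (as in f): drop zeros from both rows, then unique()
xy <- c(x, y)
direct <- unique(xy[xy != 0])

# With preprocessing (as in g): per-row unique values are computed once and
# reused for every pair that involves that row
ux <- unique(x[x != 0])
uy <- unique(y[y != 0])
cached <- unique(c(ux, uy))

identical(direct, cached)   # TRUE
```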

    Limitations

    Another aspect of the problem is how unique works under the hood. unique has an argument nmax that you can use to specify an expected maximum number of unique elements. From ?duplicated:

    Except for factors, logical and raw vectors the default nmax = NA is equivalent to nmax = length(x). Since a hash table of size 8*nmax bytes is allocated, setting nmax suitably can save large amounts of memory. For factors it is automatically set to the smaller of length(x) and the number of levels plus one (for NA). If nmax is set too small there is liable to be an error: nmax = 1 is silently ignored.

    Long vectors are supported for the default method of duplicated, but may only be usable if nmax is supplied.
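unique.default() takes the same nmax argument. As an illustration only, assume the entries are known to be integers in 0:9; then at most 10 distinct values can occur, and nmax can be set to that bound instead of defaulting to length(x):

```r
set.seed(1)
x <- sample(0:9, 1e6, replace = TRUE)  # assumption: values known to lie in 0..9

# Default: nmax = NA, i.e. a hash table sized for length(x) = 1e6 entries
u_default <- unique.default(x[x != 0])

# At most 10 distinct values are possible, so nmax = 10 is a safe upper bound
u_small <- unique.default(x[x != 0], nmax = 10L)

identical(u_default, u_small)   # TRUE: same result, far smaller hash table
```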

    These comments apply to unique as well. Since you have a 300-by-4e+07 matrix, you would be evaluating (with preprocessing):

    • unique(<4e+07-length vector>), 300 times,
    • unique(<up to 8e+07-length vector>), 299*300/2 times.

    That can consume a lot of memory if you don't know anything about your matrix that might allow you to set nmax. And it can take a long time if you don't have access to many CPUs.
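As a back-of-envelope sketch of the two bullet points, taking the 8*nmax-byte figure from the quoted help page at face value with the default nmax = length(x). This counts only the hash table allocated per call, not the vectors themselves.

```r
m <- 300      # rows of A (dimensions quoted above)
n <- 4e+07    # columns of A

row_calls  <- m                 # one unique() per row during preprocessing
pair_calls <- m * (m - 1) / 2   # one unique() per off-diagonal pair (i < j)

# Per ?duplicated: hash table of 8 * nmax bytes, default nmax = length(x)
hash_mb <- function(len) 8 * len / 1e6
c(row_calls = row_calls, pair_calls = pair_calls,
  row_hash_mb = hash_mb(n), pair_hash_mb = hash_mb(2 * n))
# row_calls 300, pair_calls 44850, row_hash_mb 320, pair_hash_mb 640
```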

    So I agree with comments asking you to consider why you need to do this at all and whether your underlying problem has a nicer solution.

    Two answers

    FWIW, here are two approaches to your general problem that actually take advantage of symmetry: f and g are the versions without and with preprocessing, respectively. [[.utri lets you extract elements from the return value, an m*(m+1)/2-length list, as if it were an m-by-m matrix; as.matrix.utri constructs the full, symmetric m-by-m list matrix.

    f <- function(A, nmax = NA) {
      a <- seq_len(nrow(A))
      J <- cbind(sequence(a), rep.int(a, a))
      FUN <- function(i) {
        if (i[1L] == i[2L]) {
          x <- A[i[1L], ]
        } else {
          x <- c(A[i[1L], ], A[i[2L], ])
        }
        unique.default(x[x != 0], nmax = nmax)
      }
      res <- apply(J, 1L, FUN, simplify = FALSE)
      class(res) <- "utri"
      res
    }
    
    g <- function(A, nmax = NA) {
      l <- lapply(asplit(A, 1L), function(x) unique.default(x[x != 0], nmax = nmax))
      a <- seq_along(l)
      J <- cbind(sequence(a), rep.int(a, a))
      FUN <- function(i) {
        if (i[1L] == i[2L]) {
          l[[i[1L]]]
        } else {
          unique.default(c(l[[i[1L]]], l[[i[2L]]]))
        }
      }
      res <- apply(J, 1L, FUN, simplify = FALSE)
      class(res) <- "utri"
      res
    }
    
    `[[.utri` <- function(x, i, j) {
      stopifnot(length(i) == 1L, length(j) == 1L)
      class(x) <- NULL
      if (i <= j) {
        x[[i + (j * (j - 1L)) %/% 2L]]
      } else {
        x[[j + (i * (i - 1L)) %/% 2L]]
      }
    }
    
    as.matrix.utri <- function(x) {
      p <- length(x)
      n <- as.integer(round(0.5 * (-1 + sqrt(1 + 8 * p))))
      i <- rep.int(seq_len(n), n)
      j <- rep.int(seq_len(n), rep.int(n, n))
      r <- i > j
      ir <- i[r]
      i[r] <- j[r]
      j[r] <- ir
      res <- x[i + (j * (j - 1L)) %/% 2L]
      dim(res) <- c(n, n)
      res
    }
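Both accessors rely on the same column-major layout of the upper triangle: pair (i, j) with i <= j sits at position i + j*(j-1)/2, because columns 1 to j-1 hold 1 + 2 + ... + (j-1) = j*(j-1)/2 entries. Conversely, a list of length p = m*(m+1)/2 gives m = (-1 + sqrt(1 + 8p))/2, which is what as.matrix.utri solves for. A quick check of both formulas (m = 6 is arbitrary):

```r
m <- 6L
a <- seq_len(m)
J <- cbind(i = sequence(a), j = rep.int(a, a))   # pair order used by f() and g()

# Position of pair (i, j), i <= j, in the m*(m+1)/2-length list
k <- J[, "i"] + (J[, "j"] * (J[, "j"] - 1L)) %/% 2L
identical(k, seq_len(nrow(J)))   # TRUE: positions 1, 2, ..., 21 in order

# Recover m from the list length p, as as.matrix.utri() does
p <- nrow(J)
as.integer(round(0.5 * (-1 + sqrt(1 + 8 * p))))  # returns 6
```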
    

    Here is a simple test on a 4-by-4 integer matrix:

    mkA <- function(m, n) {
      A <- sample(0:(n - 1L), size = as.double(m) * n, replace = TRUE, 
                  prob = rep.int(c(n - 1, 1), c(1L, n - 1L)))
      dim(A) <- c(m, n)
      A
    }
    
    set.seed(1L)
    A <- mkA(4L, 4L)
    A
    ##      [,1] [,2] [,3] [,4]
    ## [1,]    0    0    2    3
    ## [2,]    0    1    0    0
    ## [3,]    2    1    0    3
    ## [4,]    1    2    0    0
    
    identical(f(A), gA <- g(A))
    ## [1] TRUE
    
    gA[[1L, 1L]] # used for 'tcrossprod(A)[1L, 1L]'
    ## [1] 2 3
    
    gA[[1L, 2L]] # used for 'tcrossprod(A)[1L, 2L]'
    ## [1] 2 3 1
    
    gA[[2L, 1L]] # used for 'tcrossprod(A)[2L, 1L]'
    ## [1] 2 3 1
    
    gA # under the hood, an 'm*(m+1)/2'-length list
    ## [[1]]
    ## [1] 2 3
    ## 
    ## [[2]]
    ## [1] 2 3 1
    ## 
    ## [[3]]
    ## [1] 1
    ## 
    ## [[4]]
    ## [1] 2 3 1
    ## 
    ## [[5]]
    ## [1] 1 2 3
    ## 
    ## [[6]]
    ## [1] 2 1 3
    ## 
    ## [[7]]
    ## [1] 2 3 1
    ## 
    ## [[8]]
    ## [1] 1 2
    ## 
    ## [[9]]
    ## [1] 2 1 3
    ## 
    ## [[10]]
    ## [1] 1 2
    ## 
    ## attr(,"class")
    ## [1] "utri"
    
    mgA <- as.matrix(gA) # the full, symmetric, 'm'-by-'m' list matrix
    mgA
    ##      [,1]      [,2]      [,3]      [,4]     
    ## [1,] integer,2 integer,3 integer,3 integer,3
    ## [2,] integer,3 1         integer,3 integer,2
    ## [3,] integer,3 integer,3 integer,3 integer,3
    ## [4,] integer,3 integer,2 integer,3 integer,2
    
    mgA[1L, ] # used for first row of 'tcrossprod(A)'
    ## [[1]]
    ## [1] 2 3
    ## 
    ## [[2]]
    ## [1] 2 3 1
    ## 
    ## [[3]]
    ## [1] 2 3 1
    ## 
    ## [[4]]
    ## [1] 2 3 1
    
    ## If you need names
    dimnames(mgA) <- rep.int(list(sprintf("A%d", seq_len(nrow(mgA)))), 2L)
    mgA["A1", ]
    ## $A1
    ## [1] 2 3
    ## 
    ## $A2
    ## [1] 2 3 1
    ## 
    ## $A3
    ## [1] 2 3 1
    ## 
    ## $A4
    ## [1] 2 3 1
    
    ## If you need an 'm'-by-'m' 'data.table' result
    DT <- data.table::as.data.table(mgA)
    DT
    ##       A1    A2    A3    A4
    ## 1:   2,3 2,3,1 2,3,1 2,3,1
    ## 2: 2,3,1     1 1,2,3   1,2
    ## 3: 2,3,1 1,2,3 2,1,3 2,1,3
    ## 4: 2,3,1   1,2 2,1,3   1,2
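If an m-by-m table is more than you need, a long (i, j, values) layout stores only the m*(m+1)/2 distinct entries. Below is a base-R sketch; `utri_long` is a hypothetical helper, and its result can be converted with data.table::setDT() if a data.table is preferred.

```r
# Hypothetical helper: long-format view of a 'utri' list (upper triangle only)
utri_long <- function(x) {
  p <- length(x)
  m <- as.integer(round(0.5 * (-1 + sqrt(1 + 8 * p))))
  a <- seq_len(m)
  out <- data.frame(i = sequence(a), j = rep.int(a, a))  # i <= j, column-major
  out$values <- unclass(x)   # list column: one vector of values per pair
  out
}
```

Usage would be, e.g., utri_long(g(A)); the lower-triangle entry (j, i) is the same as (i, j) and is not stored.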
    

    And here are two benchmarks on two large integer matrices, showing that preprocessing can help quite a bit:

    set.seed(1L)
    A <- mkA(100L, 1e+04L)
    microbenchmark::microbenchmark(f(A), g(A), times = 10L, setup = gc(FALSE))
    ## Unit: milliseconds
    ##  expr       min        lq      mean    median        uq      max neval
    ##  f(A) 2352.0572 2383.3100 2435.7954 2403.8968 2431.6214 2619.553    10
    ##  g(A)  843.0206  852.5757  858.7262  858.2746  863.8239  881.450    10
    
    A <- mkA(100L, 1e+06L)
    microbenchmark::microbenchmark(f(A), g(A), times = 10L, setup = gc(FALSE))
    ## Unit: seconds
    ##  expr       min        lq      mean    median        uq       max neval
    ##  f(A) 290.93327 295.54319 302.57001 301.17810 307.50226 318.14203    10
    ##  g(A)  72.85608  73.83614  76.67941  76.57313  77.78056  83.73388    10
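On the CPU side mentioned under Limitations: the pairwise unions in g are independent of one another, so they can be spread over cores. The sketch below is not from the answer above; `g_par` is a hypothetical name, and it uses parallel::mclapply (the parallel package ships with R), which forks and therefore only supports mc.cores = 1 on Windows. Note that several unique() calls running at once also multiply the peak hash-table memory discussed above.

```r
library(parallel)

# Hypothetical variant of g(): cache per-row unique nonzero values, then
# compute the m*(m+1)/2 upper-triangle unions in parallel.
g_par <- function(A, cores = 1L) {
  l <- lapply(seq_len(nrow(A)), function(i) {
    x <- A[i, ]
    unique.default(x[x != 0])
  })
  a <- seq_along(l)
  I <- sequence(a)       # row index of each pair
  J <- rep.int(a, a)     # column index of each pair (I <= J)
  res <- mclapply(seq_along(I), function(k) {
    if (I[k] == J[k]) l[[I[k]]] else unique.default(c(l[[I[k]]], l[[J[k]]]))
  }, mc.cores = cores)
  class(res) <- "utri"   # same layout as f() and g(), so [[.utri applies
  res
}
```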