
Stop copy-on-modify behavior in R in a while loop


I'm doing rejection sampling in R, and it needs to be as efficient as possible. Here is my raw code:

N <- 1e8
# Proposal: sum of two Exp(3) draws
x <- rexp(N, 3) + rexp(N, 3)
# todo marks draws that were rejected and must be redrawn
todo <- runif(N, -1, 1) < cos(3.2*pi*x)

while(any(todo)){
  x[todo] <- rexp(sum(todo), 3) + rexp(sum(todo), 3)
  todo[todo] <- runif(sum(todo), -1, 1) < cos(3.2*pi*x[todo])
}

I was reading about copy-on-modify at https://adv-r.hadley.nz/names-values.html and decided to check whether any of my objects were affected, using the tracemem() function (which is part of base R; the chapter also uses the lobstr package). Lo and behold, the logical vector todo keeps getting copied in the while loop!
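For readers new to the idea, here is a minimal base-R sketch of copy-on-modify (not part of the original question): once two names point to the same vector, the next modification through either name forces R to copy it first. tracemem() only works when R was built with memory profiling, which the sketch checks with capabilities("profmem"); the addresses it prints differ from session to session, so no output is shown.

```r
# Copy-on-modify in miniature (base R only).
a <- c(TRUE, FALSE, TRUE)
b <- a                              # a and b are now two names for one vector
traced <- capabilities("profmem")   # tracemem() needs memory profiling support
if (traced) cat(tracemem(a), "\n")
a[1] <- FALSE                       # the vector is shared, so R copies it before writing
if (traced) untracemem(a)
print(b)                            # b still holds the original values
```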

If I run the following test:

library(lobstr)
N <- 1e8
x <- rexp(N, 3) + rexp(N, 3)
todo <- runif(N, -1, 1) < cos(3.2*pi*x)
cat(tracemem(todo), "\n")

while(any(todo)){
  x[todo] <- rexp(sum(todo), 3) + rexp(sum(todo), 3)
  todo[todo] <- runif(sum(todo), -1, 1) < cos(3.2*pi*x[todo])
}

I get the following output, which confirms my worries:

[1] "0x38f4938"
[1] "0x1959ff30"
[1] "0x38f4938"
[1] "0x173bb6f0"
[1] "0x38f4938"
[1] "0x4caf788"
[1] "0x38f4938"
[1] "0x1a801628"
[1] "0x38f4938"
[1] "0x18f36768"
[1] "0x38f4938"
[1] "0x4e4d478"
[1] "0x38f4938"
[1] "0x195b93d8"
[1] "0x38f4938"
[1] "0x3f59fe0"
[1] "0x38f4938"
[1] "0x45ebf40"
[1] "0x38f4938"
[1] "0x1a42bdd8"
[1] "0x38f4938"
[1] "0x16c72ba0"

Can someone please help me get rid of this time-consuming behavior? I tried the following declarations, without success:

  1. todo <- logical(N)

  2. todo <- list(logical(N))
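Neither declaration changes how todo is later used, which may be why they don't help. A plausible explanation (our reading, not confirmed in the original thread) is that the copy is triggered by todo[todo] <- ...: the same vector is both the object being modified and the index passed to `[<-`, so R treats it as shared and duplicates it before writing. Indexing with a separate integer vector from which() should avoid that. The sketch below lets you watch the difference with tracemem() on a tiny vector; the names v and i are illustrative.

```r
v <- c(TRUE, FALSE, TRUE, TRUE)
traced <- capabilities("profmem")   # tracemem() needs memory profiling support
if (traced) cat(tracemem(v), "\n")

# 1) Self-indexing: v is both the target and the index, so expect a tracemem line here.
v[v] <- c(FALSE, TRUE, FALSE)

# 2) Integer indexing: which() returns a separate vector, so v should not be shared.
i <- which(v)
v[i] <- FALSE

if (traced) untracemem(v)
print(v)
```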

Edit: any help with improving the speed of this bottleneck in my code would also be much appreciated.
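On the efficiency side, one base-R option worth trying (a sketch, not benchmarked here) is to keep the positions that still need a draw as an integer vector that shrinks every round. Each iteration then only touches the pending elements, never uses a vector as its own index, and counts them once instead of calling sum(todo) three times. The function name sample_pending() and its arguments are hypothetical; the proposal and the rejection test are copied from the question.

```r
# Base-R rejection loop that only touches the still-pending positions.
# Assumption: same proposal (sum of two Exp(3)) and rejection test as the question.
sample_pending <- function(N, rate = 3, freq = 3.2) {
  x <- numeric(N)
  idx <- seq_len(N)                  # positions that still need an accepted draw
  while (length(idx) > 0L) {
    n <- length(idx)                 # count once per round
    cand <- rexp(n, rate) + rexp(n, rate)
    x[idx] <- cand                   # x is local and unshared, so this can modify in place
    reject <- runif(n, -1, 1) < cos(freq * pi * cand)
    idx <- idx[reject]               # keep only the rejected positions for the next round
  }
  x
}

set.seed(1)
x <- sample_pending(1e5)
summary(x)
```

If it helps, rexp(n, rate) + rexp(n, rate) can be replaced by rgamma(n, shape = 2, rate = rate): the sum of two independent Exp(rate) variables is Gamma(2, rate), so the distribution is the same (the random stream is not), and it needs one call instead of two. Whether it is actually faster is worth measuring with system.time().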


Solution

  • I don't think that's possible in base R, but you can optimise this procedure with the data.table package. Try the following and you will see that no copy is made (I also made some other minor changes to further optimise your code):

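    library(data.table)
    
    N <- 1e7; i <- 1:N
    # one column for the draws, one for the rejection flags
    dt <- list(x = double(N), todo = logical(N)); setDT(dt)
    cat(tracemem(dt), "\n")
    # N and i are rebound inside the loop, so these trace only their starting values
    cat(tracemem(N), "\n")
    cat(tracemem(i), "\n")
    
    while(N > 0L){
      # set() writes rows i of a column by reference, without copying dt
      set(dt, i, "x", rexp(N, 3) + rexp(N, 3))
      set(dt, i, "todo", runif(N, -1, 1) < cos(3.2*pi*dt[i]$x))
      # rows still rejected; N shrinks each round until every draw is accepted
      N <- length(i <- which(dt$todo))
    }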
    

    It takes about 3 seconds to run with N = 1e7, which is still a bit long, so I think there is room for further improvement.

    system.time({
      N <- 1e7; i <- 1:N
      dt <- list(x = double(N), todo = logical(N)); setDT(dt)
      
      while(N > 0L){
        set(dt, i, "x", rexp(N, 3) + rexp(N, 3))
        set(dt, i, "todo", runif(N, -1, 1) < cos(3.2*pi*dt[i]$x))
        N <- length(i <- which(dt$todo))
      }
    })
    

    Result

     user  system elapsed 
     3.26    0.10    3.34
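Because the accepted draws are i.i.d., another direction (a different technique from the answer above, offered as an untested sketch) is to drop position tracking entirely: draw a batch larger than the number still needed, keep the accepted values, and repeat until N have been collected. The name sample_batched() is hypothetical and the default oversample = 2 is an arbitrary guess rather than a tuned value; setting it near the reciprocal of the acceptance rate should keep the number of rounds small.

```r
# Batch oversampling: collect accepted draws until N are available.
# Assumption: oversample = 2 is a guess, not a tuned value.
sample_batched <- function(N, rate = 3, freq = 3.2, oversample = 2) {
  chunks <- list()
  have <- 0L
  while (have < N) {
    m <- ceiling((N - have) * oversample)          # batch size for this round
    cand <- rexp(m, rate) + rexp(m, rate)
    acc <- cand[runif(m, -1, 1) >= cos(freq * pi * cand)]  # accepted draws only
    chunks[[length(chunks) + 1L]] <- acc
    have <- have + length(acc)
  }
  # combine once at the end instead of growing a vector inside the loop
  unlist(chunks, use.names = FALSE)[seq_len(N)]
}

set.seed(1)
x <- sample_batched(1e5)
```

Collecting the batches in a list and calling unlist() once avoids repeatedly copying a growing vector; truncating to the first N accepted values keeps them i.i.d. because the cut does not depend on the values.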