Improved inverse transform method for Poisson random variable generation in R


I am reading Section 4.2 of Simulation (2006, 4th ed., Elsevier) by Sheldon M. Ross, which introduces generating a Poisson random variable by the inverse transform method.

Denote p_i = P(X=i) = e^{-λ} λ^i/i!, i=0,1,..., and F(i) = P(X<=i) = Σ_{k=0}^i p_k as the PMF and CDF of the Poisson distribution, respectively. These can be computed via dpois(x,lambda) and ppois(x,lambda) in R.
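As a quick sanity check of these definitions, the following sketch sums the PMF terms and compares the result with ppois. The values lambda = 3 and i = 5 are arbitrary choices for illustration, not taken from the book.

```r
lambda = 3; i = 5               # arbitrary illustrative values
p = dpois(0:i, lambda)          # p_0, ..., p_i
F_i = sum(p)                    # F(i) = sum of p_k for k = 0..i
all.equal(F_i, ppois(i, lambda))    # CDF equals the cumulative sum of the PMF
all.equal(p[1], exp(-lambda))       # p_0 = e^{-lambda}
```

Both comparisons use all.equal rather than ==, since the two sides are computed by different floating-point routes.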

There are two inverse transform algorithms for Poisson: the regular version and the improved one.

The steps for the regular version are as follows:

  1. Simulate an observation U from U(0,1).
  2. Set i=0 and F=F(0)=p_0=e^{-λ}.
  3. If U<F, set X=i and terminate.
  4. If U>=F, set i=i+1 and F=F+p_i, then return to step 3.

I write and test the above steps as follows:

### write the regular R code
pois_inv_trans_regular = function(n, lambda){
  X = rep(0, n) # generate n samples
  for(m in 1:n){
    U = runif(1)
    i = 0; F = exp(-lambda) # initialize
    while(U >= F){
      i = i+1; F = F + dpois(i,lambda) # F=F+pi
    }
    X[m] = i
  }
  X
}
### test the code (for small λ, e.g. λ=3)
set.seed(0); X = pois_inv_trans_regular(n=10000,lambda=3); c(mean(X),var(X))
# [1] 3.005000 3.044079

Since the mean and variance of Poisson(λ) are both λ, the regular code appears to work correctly.

Next I tried the improved version, which is designed for large λ and is described in the book as follows:

  • The regular algorithm needs about 1+λ searches on average, i.e. O(λ) complexity. This is fine when λ is small, but it can be greatly improved upon when λ is large.

  • Indeed, since a Poisson random variable with mean λ is most likely to take on one of the two integral values closest to λ, a more efficient algorithm would first check one of these values, rather than starting at 0 and working upward. For instance, let I=Int(λ) and recursively determine F(I).

  • Now generate a Poisson random variable X with mean λ by generating a random number U, noting whether or not X <= I by seeing whether or not U <= F(I). Then search downward starting from I in the case where X <= I and upward starting from I+1 otherwise.

  • The improved algorithm is said to need only about 1+0.798√λ searches on average, i.e. O(√λ) complexity.
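The constant 0.798 is close to sqrt(2/π), the mean absolute value of a standard normal variable. Here is a rough numerical check, assuming the search count is about 1 + E|X − λ| (an assumption made here, not a statement verified against the book). The truncation point kmax is an arbitrary choice far in the upper tail.

```r
lambda = 100
kmax = 1000                                     # truncation point (assumption); dpois is negligible beyond it
k = 0:kmax
mad = sum(abs(k - lambda) * dpois(k, lambda))   # E|X - lambda| for X ~ Poisson(lambda)
c(exact = mad, approx = 0.798 * sqrt(lambda))   # compare with the normal approximation
```

For λ=100 both numbers should come out close to 8, compared with roughly 101 searches for the regular algorithm.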

I tried to write the R code for the improved version as follows:

### write the improved R code
pois_inv_trans_improved = function(n, lambda){
  X = rep(0, n) # generate n samples
  p = function(x) {dpois(x,lambda)} # PMF: p(x) = P(X=x) = λ^x exp(-λ)/x!
  F = function(x) {ppois(x,lambda)} # CDF: F(x) = P(X ≤ x)
  I = floor(lambda) # I=Int(λ)
  F1 = F(I); F2 = F(I+1) # two close values
  for(k in 1:n){
    U = runif(1)
    i = I
    if ( F1 < U  &  U <= F2 ) { 
      i = I+1 
    } 
    while (U <= F1){ # search downward
      i = i-1; F1 = F1 - p(i)
    }
    while (U > F2){ #  search upward
      i = i+1; F2 = F2 + p(i)
    }
    X[k] = i
  }
  X
}
### test the code (for large λ, e.g. λ=100)
set.seed(0); X = pois_inv_trans_improved(n=10000,lambda=100); c(mean(X),var(X))
# [1] 100.99900000   0.02180118

The simulation gives [1] 100.99900000 0.02180118 for c(mean(X),var(X)). The variance is clearly wrong, since it should be close to λ=100. How can I fix this?


Solution

  • The main problem is that F1 and F2 are modified inside the loop and never reset, so over time an ever wider range of U values is treated as falling in the middle.
    The second problem is in the downward search: the p(i) that is subtracted should use i before it is decremented, because F(x) = P(X <= x) and therefore F(i-1) = F(i) - p(i). Without this, the code hangs for low U. The easiest fix is to start at i = I + 1. Then the "in the middle" if statement isn't needed.

    pois_inv_trans_improved = function(n, lambda){
      X = rep(0, n) # generate n samples
      p = function(x) {dpois(x,lambda)} # PMF: p(x) = P(X=x) = λ^x exp(-λ)/x!
      F = function(x) {ppois(x,lambda)} # CDF: F(x) = P(X ≤ x)
      I = floor(lambda) # I=Int(λ)
      F1 = F(I); F2 = F(I+1) # CDF at the two starting points; never modified below
      for(k in 1:n){
        U = runif(1)
        i = I + 1 # start at I+1, so no separate "in the middle" check is needed
        F1tmp = F1 # per-draw copy, so the search restarts from F(I) for every sample
        while (U <= F1tmp){ # search downward; F1tmp = F(i-1) on each test
          i = i-1; F1tmp = F1tmp - p(i) # subtract p(i) for the new i, giving F(i-1)
        }
        F2tmp = F2 # per-draw copy of F(I+1)
        while (U > F2tmp){ # search upward; F2tmp = F(i) whenever this loop runs
          i = i+1; F2tmp = F2tmp + p(i)
        }
        X[k] = i
      }
      X
    }
    

    This gives:

    [1] 100.0056 102.2380
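
As an additional check, inverse transform sampling maps each U to the smallest i with U <= F(i), which is also what R's built-in qpois(U, lambda) returns. The sketch below uses a hypothetical single-draw helper, pois_from_u (a name introduced here for illustration, not part of the code above), that repeats the same two-way search for a given U and compares the result with qpois. Exact agreement assumes that no U lands within floating-point error of a CDF value.

```r
pois_from_u = function(U, lambda){    # hypothetical helper for illustration
  I = floor(lambda)
  i = I + 1
  Fdown = ppois(I, lambda)            # equals F(i-1) for the current i
  while (U <= Fdown){                 # search downward
    i = i - 1; Fdown = Fdown - dpois(i, lambda)
  }
  Fup = ppois(I + 1, lambda)          # equals F(i) whenever the loop below runs
  while (U > Fup){                    # search upward
    i = i + 1; Fup = Fup + dpois(i, lambda)
  }
  i
}
set.seed(1)
U = runif(1000)
x = sapply(U, pois_from_u, lambda = 100)
all(x == qpois(U, 100))               # should be TRUE, barring floating-point ties
```

Feeding the same U values to both routines compares them draw by draw, which is a stricter test than comparing sample means and variances.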