
3 nested loops: Optimizing a simple simulation for speed


Background

I've come across a puzzle. Here it is:

One day, an alien comes to Earth. Every day, each alien does one of the following four things, each with equal probability:

  • Kill himself
  • Do nothing
  • Split himself into two aliens (while killing himself)
  • Split himself into three aliens (while killing himself)

What is the probability that the alien species eventually dies out entirely?

Link to the source and the solution, problem #10

Unfortunately, I haven't been able to solve the problem theoretically, so I moved on to simulating it, with a basic Markov chain and Monte Carlo simulation in mind.

This was not asked to me in an interview. I learned the problem from a friend, then found the link above while searching for mathematical solutions.

Reinterpreting the question

We start with n = 1 alien. Each day, every alien changes n in one of four ways, with 25% probability each: n is decremented by 1 (it dies), stays unchanged (it does nothing), is incremented by 1 (it splits into two) or is incremented by 2 (it splits into three). On the next day, each of the now n aliens does its thing again, so the procedure is repeated n times. I have to put an upper limit on n, though, so that the simulation stops and doesn't crash: n is likely to keep increasing, and we loop n times every day.

If the aliens go extinct, we also stop simulating, since there's nothing left to simulate.

Once n reaches zero or the upper limit, we record the population (it will be either zero or some number >= pop_max).

I repeat this many times and record every result. At the end, the number of zeros divided by the total number of results should approximate the probability of extinction.
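Since each run either ends at zero or not, the estimate is a sample proportion, and its statistical error can be quantified. A minimal sketch, assuming the runs are independent Bernoulli trials; `extinction_estimate` is a hypothetical helper, not part of the original code:

```python
import numpy as np

def extinction_estimate(results):
    """Hypothetical helper: return (p_hat, standard_error) for the fraction of
    runs that ended with a population of 0.

    Assumes every run is an independent Bernoulli trial (extinct or not), so the
    standard error of the sample proportion is sqrt(p * (1 - p) / N).
    """
    results = np.asarray(results)
    n_runs = results.size
    p_hat = np.count_nonzero(results == 0) / n_runs  # fraction of extinct runs
    std_err = np.sqrt(p_hat * (1 - p_hat) / n_runs)  # binomial standard error
    return p_hat, std_err
```

For p around 0.41 and N = 100000 runs, the standard error is about 0.0016, so the third decimal place of any single run is mostly noise.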

The code

from random import randint
import numpy as np

pop_max = 100
iter_max = 100000

results = np.zeros(iter_max, dtype=int)

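for i in range(iter_max):
    n = 1  # start with a single alien
    while n > 0 and n < pop_max:  # stop at extinction or once the cap is reached
        for j in range(n):  # range(n) is fixed at the start of the day: each alien acts once
            x = randint(1, 4)  # 1..4 inclusive, each with probability 1/4
            if x == 1:    # dies
                n = n - 1
            elif x == 2:  # does nothing
                continue
            elif x == 3:  # splits into two: net +1
                n = n + 1
            elif x == 4:  # splits into three: net +2
                n = n + 2
    results[i] = n  # either 0 (extinct) or >= pop_max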

print( np.bincount(results)[0] / iter_max )

iter_max and pop_max can of course be changed, but I figured that once there are 100 aliens, the probability that they all go extinct is negligibly low. This is just a guess, though; I haven't done anything to calculate a more principled upper limit for the population.
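The guess about pop_max can be checked numerically. A sketch, assuming the standard branching-process facts that (a) the probability that one alien's lineage is extinct by day k is f applied k times to 0, where f(s) = (1 + s + s^2 + s^3) / 4 is the generating function of one alien's offspring count, and (b) the lineages of n aliens are independent, so all n die out with probability q**n. The function names are hypothetical:

```python
def offspring_pgf(s):
    """Hypothetical helper: probability generating function of the number of
    aliens one alien leaves behind after a day (0, 1, 2 or 3, each w.p. 1/4)."""
    return (1 + s + s**2 + s**3) / 4

def extinction_probability(tol=1e-15, max_iter=10_000):
    """Hypothetical helper: P(a single alien's lineage eventually dies out).

    Iterates q_{k+1} = f(q_k) from q_0 = 0, where q_k = P(extinct by day k);
    the sequence increases monotonically to the smallest fixed point of f.
    """
    q = 0.0
    for _ in range(max_iter):
        q_next = offspring_pgf(q)
        if abs(q_next - q) < tol:
            return q_next
        q = q_next
    return q

q = extinction_probability()
pop_max = 100
# Independent lineages: a population of pop_max aliens dies out with probability q**pop_max,
# which bounds how much the cap can bias the Monte Carlo estimate.
print(f"q = {q:.6f}, q**pop_max = {q**pop_max:.3e}")
```

With q around 0.414, q**100 is on the order of 1e-39, so a cap of 100 aliens introduces no measurable bias.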

This code gives promising results, fairly close to the real answer, which is approximately 41.4%.
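For reference, a sketch of where the 41.4% comes from, treating the aliens as a Galton-Watson branching process (this framing is my addition, not taken from the linked solution). It assumes the standard result that the extinction probability q is the smallest root in [0, 1] of f(s) = s, where f is the probability generating function of one alien's offspring count; because the mean offspring count is greater than 1, that root is strictly less than 1:

```latex
% Offspring pgf: one alien leaves 0, 1, 2 or 3 aliens, each with probability 1/4
f(s) = \tfrac{1}{4}\left(1 + s + s^{2} + s^{3}\right), \qquad
\mu = f'(1) = \tfrac{1 + 2 + 3}{4} = \tfrac{3}{2} > 1
% Fixed-point equation for the extinction probability
f(s) = s \iff s^{3} + s^{2} - 3s + 1 = 0 \iff (s - 1)\left(s^{2} + 2s - 1\right) = 0
% Roots: s = 1 and s = -1 \pm \sqrt{2}; the smallest one in [0, 1] is
q = \sqrt{2} - 1 \approx 0.41421
```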

Some outputs

> python aliens.py
0.41393
> python aliens.py
0.41808
> python aliens.py
0.41574
> python aliens.py
0.4149
> python aliens.py
0.41505
> python aliens.py
0.41277
> python aliens.py
0.41428
> python aliens.py
0.41407
> python aliens.py
0.41676

Aftermath

I'm okay with the results but I can't say the same for the time this code takes. It takes around 16-17 seconds :)

How can I improve the speed? How can I optimize the loops (especially the while loop)? Is there perhaps a much better approach or model?


Solution

  • You can vectorize your inner for loop by generating all n random integers at once with numpy (much faster) and get rid of all your if statements by using arithmetic instead of boolean logic:

    while 0 < n < pop_max:
        # population changes by -1, 0, +1 or +2 for each alien (randint's upper bound is exclusive)
        n += np.random.randint(-1, 3, size=n).sum()
    

    Using your exact code for everything else (you could probably find other optimizations elsewhere), I went from 21.2 sec to 4.3 sec with this one change.

    Without changing the algorithm (i.e. solving with a method other than Monte Carlo), I don't see any other sweeping changes that could make it much faster until you get into compiling to machine code (which is fortunately very easy if you have numba installed).

    I won't give a full tutorial on the just-in-time compilation that numba performs; instead, I'll just share my code and note the changes I made:

    from time import time
    import numpy as np
    from numpy.random import randint
    from numba import njit, int32, prange
    
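    @njit('i4(i4)')  # explicit signature: compiled eagerly when the decorator runs, so the timing below excludes compilation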
    def simulate(pop_max): #move simulation of one population to a function for parallelization
        n = 1
        while 0 < n < pop_max:
            n += np.sum(randint(-1,3,n))
        return n
    
    @njit('i4[:](i4,i4)', parallel=True)  # parallel=True is needed for prange to actually run in parallel
    def solve(pop_max, iter_max):
        # this could easily be simplified to just return the ratio of populations that die off vs. survive to pop_max,
        # which would save you some RAM (though the speed is about the same)
        results = np.zeros(iter_max, dtype=int32) #numba needs int32 here rather than python int
        for i in prange(iter_max): #prange specifies that this loop can be parallelized
            results[i] = simulate(pop_max)
        return results
    
    pop_max = 100
    iter_max = 100000
    
    t = time()
    print( np.bincount(solve(pop_max, iter_max))[0] / iter_max )
    print('time elapsed: ', time()-t)
    

    Compilation with parallelization gets the run time down to about 0.15 seconds on my system.
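One further option, without numba, is a swapped-in technique that is not part of the answer above: instead of drawing one action per alien, draw how many aliens take each action from a multinomial distribution, which is equivalent in distribution to n independent draws but costs a constant amount per population per day. NumPy's `Generator.multinomial` broadcasts over an array of trial counts, so all populations can also be stepped at once. A minimal sketch under those assumptions; `extinction_fraction` is a hypothetical name:

```python
import numpy as np

def extinction_fraction(pop_max=100, iter_max=100_000, seed=None):
    """Hypothetical helper: fraction of runs that die out before reaching pop_max.

    All iter_max populations are simulated together. Each day, the number of
    aliens that die / do nothing / split into 2 / split into 3 is drawn from a
    multinomial with equal probabilities, instead of one draw per alien.
    """
    rng = np.random.default_rng(seed)
    pop = np.ones(iter_max, dtype=np.int64)   # every run starts with one alien
    deltas = np.array([-1, 0, 1, 2])          # net population change per action
    active = np.ones(iter_max, dtype=bool)    # runs not yet extinct or capped
    while active.any():
        idx = np.flatnonzero(active)
        # counts has shape (len(idx), 4): how many aliens took each action today
        counts = rng.multinomial(pop[idx], [0.25] * 4)
        pop[idx] += counts @ deltas
        active[idx] = (pop[idx] > 0) & (pop[idx] < pop_max)
    return np.mean(pop == 0)

print(extinction_fraction(seed=0))
```

The per-day cost no longer grows with the population size, and the remaining Python loop only runs once per simulated day rather than once per alien per run.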