Tags: c#, algorithm, math, statistics, bernoulli-probability

How to get the detailed results (probability tree) of running a Bernoulli trial a large number of times


Suppose the following experiment: run the same Bernoulli trial (with probability of success P) N times.

I need the following information: every possible sequence of successes and failures, together with its probability of occurring.

Example: a Bernoulli experiment with probability of success P = 40%, executed 3 times, would yield the following results (S is a success, F is a failure):

FFF 0.216

SFF 0.144

FSF 0.144

SSF 0.096

FFS 0.144

SFS 0.096

FSS 0.096

SSS 0.064
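
For reference, the table above can be reproduced with a few lines of Python. This is only a sketch for small N; the function name `sequence_probabilities` is made up for illustration. It relies on the fact that a sequence's probability depends only on how many successes it contains: P**wins * (1 - P)**(N - wins).

```python
from itertools import product


def sequence_probabilities(n, p):
    """Yield (sequence, probability) for every S/F sequence of length n."""
    for outcome in product("FS", repeat=n):
        wins = outcome.count("S")
        # Trials are independent, so the probabilities simply multiply.
        yield "".join(outcome), p**wins * (1 - p)**(n - wins)


for seq, prob in sequence_probabilities(3, 0.4):
    print(seq, round(prob, 3))
```

The generator never stores the sequences, so memory stays flat, but it still visits all 2**N of them and is therefore only usable for small N.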

I tried to brute-force the results, but it chokes quickly: with only N = 25 I get an OutOfMemoryException.

using System;
using System.Linq;
using System.Collections.Generic;
using System.Text.RegularExpressions;

namespace ConsoleApplication
{
    class Program
    {
        static Dictionary<string, double> finalResultProbabilities = new Dictionary<string, double>();

        static void Main(string[] args)
        {
            // OutOfMemoryException if I set it to 25 :(
            //var nbGames = 25;
            var nbGames = 3;
            var probabilityToWin = 0.4d;

            CalculateAverageWinningStreak(string.Empty, 1d, nbGames, probabilityToWin);

            // Do something with the finalResultProbabilities data...
        }

        static void CalculateAverageWinningStreak(string currentResult, double currentProbability, int nbGamesRemaining, double probabilityToWin)
        {
            if (nbGamesRemaining == 0)
            {
                finalResultProbabilities.Add(currentResult, currentProbability);
                return;
            }

            CalculateAverageWinningStreak(currentResult + "S", currentProbability * probabilityToWin, nbGamesRemaining - 1, probabilityToWin);
            CalculateAverageWinningStreak(currentResult + "F", currentProbability * (1 - probabilityToWin), nbGamesRemaining - 1, probabilityToWin);
        }
    }
}

I need to support up to N = 3000 in a timely manner (obtaining the result in less than 3 seconds for any P).

Is there a mathematical way to do this optimally?
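
A note on feasibility: a full listing has 2**N rows, so for N = 3000 it cannot be produced at all, whatever the algorithm. If only the distinct probabilities are needed (an assumption, not something stated above), the sequences can be grouped by their number of successes, which gives N + 1 classes. The sketch below uses a hypothetical helper `outcome_classes` and works in log10, because 0.4**3000 is about 10**-1194 and underflows a double.

```python
import math


def outcome_classes(n, p):
    """Return one row per success count k:
    (k, number of sequences with k successes, log10 of each one's probability).

    Assumes 0 < p < 1 so that both logarithms are defined.
    """
    log_p, log_q = math.log10(p), math.log10(1 - p)
    return [(k, math.comb(n, k), k * log_p + (n - k) * log_q)
            for k in range(n + 1)]


rows = outcome_classes(3000, 0.4)
print(len(rows))  # one row per success count, i.e. n + 1 rows
```

This grouping discards the order of successes and failures, so it is not enough for order-dependent quantities such as the longest winning streak.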


Solution

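  • Here's a different approach, which is exact and fast enough, being only quadratic in n. The longest win streak is a whole number between 0 and n, so it equals the number of values k in 1..n for which a streak of length at least k exists. Taking expectations, the expected value of the longest win streak is equal to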

     n
    sum Pr(there exists a win streak of length at least k).
    k=1
    

    We reason about the probability of a win streak of length at least k in n tries as follows. Either the record opens with k wins (probability pwin**k), or it opens with j wins for some j in 0..k-1 followed by a loss (probability pwin**j * (1 - pwin)), in which case the streak has to occur in the remaining n - (j + 1) tries. We use memoization to evaluate the recurrence that this logic implies in pwinstreak; the faster version in fastpwinstreak uses algebra to avoid repeated summations.

    def avglongwinstreak(n, pwin):
        return sum(fastpwinstreak(n, pwin, k) for k in range(1, n + 1))
    
    
    def pwinstreak(n, pwin, k):
        memo = [0] * (n + 1)
        for m in range(k, n + 1):
            memo[m] = pwin**k + sum(pwin**j * (1 - pwin) * memo[m - (j + 1)]
                                    for j in range(k))
        return memo[n]
    
    
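    def fastpwinstreak(n, pwin, k):
        # Same recurrence as pwinstreak, but windowsum carries
        #   sum(pwin**j * (1 - pwin) * memo[m - (j + 1)] for j in range(k))
        # from one m to the next, so each step costs O(1) instead of O(k).
        pwink = pwin**k
        memo = [0] * (n + 1)
        windowsum = 0
        for m in range(k, n + 1):
            memo[m] = pwink + windowsum
            # Shift the window: scale by pwin, add the new term for memo[m],
            # and drop the term for memo[m - k], which falls out of range.
            windowsum = pwin * windowsum + (1 - pwin) * (memo[m] - pwink *
                                                         memo[m - k])
        return memo[n]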
    
    
    print(avglongwinstreak(3000, 0.4))
    

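    A version that accepts an absolute error tolerance and stops summing early. The probability of a streak of length at least k cannot increase with k, so the n - k omitted terms add up to at most (n - k) * p: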

    def avglongwinstreak(n, pwin, abserr=0):
        avg = 0
        for k in range(1, n + 1):
            p = fastpwinstreak(n, pwin, k)
            avg += p
            if (n - k) * p < abserr:
                break
        return avg
    
    
    def fastpwinstreak(n, pwin, k):
        pwink = pwin**k
        memo = [0] * (n + 1)
        windowsum = 0
        for m in range(k, n + 1):
            memo[m] = pwink + windowsum
            windowsum = pwin * windowsum + (1 - pwin) * (memo[m] - pwink *
                                                         memo[m - k])
        return memo[n]
    
    
    print(avglongwinstreak(3000, 0.4, 1e-6))
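
As a sanity check, the recurrence can be compared with direct enumeration for a small n. The sketch below copies fastpwinstreak from the answer; `longest_streak` and `bruteforce_avg` are helper names introduced here for illustration only.

```python
from itertools import product


def fastpwinstreak(n, pwin, k):
    # Copied unchanged from the answer above.
    pwink = pwin**k
    memo = [0] * (n + 1)
    windowsum = 0
    for m in range(k, n + 1):
        memo[m] = pwink + windowsum
        windowsum = pwin * windowsum + (1 - pwin) * (memo[m] - pwink *
                                                     memo[m - k])
    return memo[n]


def longest_streak(outcome):
    """Length of the longest run of consecutive 'S' in the sequence."""
    best = run = 0
    for c in outcome:
        run = run + 1 if c == "S" else 0
        best = max(best, run)
    return best


def bruteforce_avg(n, pwin):
    """Expected longest streak by enumerating all 2**n sequences."""
    total = 0.0
    for outcome in product("SF", repeat=n):
        wins = outcome.count("S")
        total += pwin**wins * (1 - pwin)**(n - wins) * longest_streak(outcome)
    return total


n, pwin = 10, 0.4
exact = sum(fastpwinstreak(n, pwin, k) for k in range(1, n + 1))
print(abs(exact - bruteforce_avg(n, pwin)) < 1e-9)
```

For the N = 3, P = 40% example from the question, both methods give an expected longest streak of 1.104, which can be confirmed by hand from the eight listed sequences.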