
Porting the PyMC2 Bayesian A/B testing example to PyMC3


I am trying to learn PyMC3 and am having some trouble. Since there are few PyMC3 tutorials, I am working from Bayesian Methods for Hackers. I'm trying to port the PyMC2 code in its Bayesian A/B testing example to PyMC3, without success. From what I can see, the model isn't taking the observations into account at all.

I've had to make a few changes to the example because PyMC3 is quite different, so code that looks like this in PyMC2:

import pymc as pm
import matplotlib.pyplot as plt

# The parameters are the bounds of the Uniform.
p = pm.Uniform('p', lower=0, upper=1)

# set constants
p_true = 0.05  # remember, this is unknown.
N = 1500

# sample N Bernoulli random variables from Ber(0.05).
# each random variable has a 0.05 chance of being a 1.
# this is the data-generation step
occurrences = pm.rbernoulli(p_true, N)

print(occurrences)  # Remember: Python treats True == 1, and False == 0
print(occurrences.sum())

# occurrences.mean() is equal to n/N.
print("What is the observed frequency in Group A? %.4f" % occurrences.mean())
print("Does this equal the true frequency? %s" % (occurrences.mean() == p_true))

# include the observations, which are Bernoulli
obs = pm.Bernoulli("obs", p, value=occurrences, observed=True)

# To be explained in chapter 3
mcmc = pm.MCMC([p, obs])
mcmc.sample(18000, 1000)

plt.figure(figsize=(12.5, 4))
plt.title("Posterior distribution of $p_A$, the true effectiveness of site A")
plt.vlines(p_true, 0, 90, linestyle="--", label="true $p_A$ (unknown)")
plt.hist(mcmc.trace("p")[:], bins=25, histtype="stepfilled", density=True)
plt.legend()

instead looks like this in my PyMC3 port:

import pymc3 as pm

import random
import numpy as np
import matplotlib.pyplot as plt

with pm.Model() as model:
    # Prior is uniform: all cases are equally likely
    p = pm.Uniform('p', lower=0, upper=1)

    # set constants
    p_true = 0.05  # remember, this is unknown.
    N = 1500

    # sample N Bernoulli random variables from Ber(0.05).
    # each random variable has a 0.05 chance of being a 1.
    # this is the data-generation step
    occurrences = []  # pm.rbernoulli(p_true, N)
    for i in range(N):
        occurrences.append((random.uniform(0.0, 1.0) <= p_true))
    occurrences = np.array(occurrences)
    obs = pm.Bernoulli('obs', p_true, observed=occurrences)

    start = pm.find_MAP()
    step = pm.Metropolis()
    trace = pm.sample(18000, step=step, start=start)
    pm.traceplot(trace);
    plt.show()

Apologies for the lengthy post, but my adaptation involves a number of small changes, e.g. generating the observations manually because pm.rbernoulli no longer exists. I'm also not sure whether I should call find_MAP to get a starting point before sampling. How should I change my implementation so that it runs correctly?
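As an aside on the data-generation step, the Python loop standing in for pm.rbernoulli can be replaced by a single vectorized NumPy call. This is a minimal sketch, assuming NumPy 1.17 or later (for `np.random.default_rng`); the helper name `simulate_bernoulli` and the seed are illustrative and are not part of PyMC.

```python
import numpy as np


def simulate_bernoulli(p_true, n, seed=None):
    """Hypothetical helper: draw n Bernoulli(p_true) outcomes as a boolean array."""
    rng = np.random.default_rng(seed)
    # A uniform draw u in [0, 1) satisfies u < p_true with probability p_true,
    # so this has the same distribution as the loop's `uniform(0, 1) <= p_true`.
    return rng.random(n) < p_true


p_true = 0.05  # remember, this is unknown in practice
N = 1500
occurrences = simulate_bernoulli(p_true, N, seed=42)

print(occurrences.sum())   # number of successes (True values)
print(occurrences.mean())  # observed frequency: close to p_true, but rarely equal to it
```

If 0/1 integers are preferred over booleans, `rng.binomial(1, p_true, size=N)` on the same generator draws the equivalent Bernoulli sample.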


Solution

  • You were indeed close. However, this line:

    obs = pm.Bernoulli('obs', p_true, observed=occurrences)
    

    is wrong because it fixes the Bernoulli probability to the constant p_true (0.05) instead of your random variable p. As a result, p, which you gave a uniform prior above, never enters the likelihood, so the observations do not constrain it, and your trace plot simply shows samples from the prior. Replacing p_true with p in that line should make it work. Here is the fixed version:

    import pymc3 as pm
    
    import random
    import numpy as np
    import matplotlib.pyplot as plt
    
    with pm.Model() as model:
        # Prior is uniform: all cases are equally likely
        p = pm.Uniform('p', lower=0, upper=1)
    
        # set constants
        p_true = 0.05  # remember, this is unknown.
        N = 1500
    
        # sample N Bernoulli random variables from Ber(0.05).
        # each random variable has a 0.05 chance of being a 1.
        # this is the data-generation step
        occurrences = []  # pm.rbernoulli(p_true, N)
        for i in range(N):
            occurrences.append((random.uniform(0.0, 1.0) <= p_true))
        occurrences = np.array(occurrences)
        obs = pm.Bernoulli('obs', p, observed=occurrences)
    
        start = pm.find_MAP()
        step = pm.Metropolis()
        trace = pm.sample(18000, step=step, start=start)
    
    pm.traceplot(trace)
    plt.show()
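To see why swapping p_true for p matters, here is a NumPy-only sketch (no PyMC required) that evaluates the unnormalized log posterior of p on a grid for both versions of the likelihood. It assumes the same Uniform(0, 1) prior and simulated data as above. The function names `bernoulli_loglik` and `normalize` and the grid resolution are illustrative choices. With p_true in the likelihood, the log-likelihood is the same constant at every grid value of p, so the posterior reduces to the flat prior. With p, it concentrates near the observed frequency. Because Uniform(0, 1) is Beta(1, 1), which is conjugate to the Bernoulli likelihood, the exact posterior is Beta(1 + k, 1 + N - k), where k is the number of ones. This gives an independent check on the MCMC trace.

```python
import numpy as np


def bernoulli_loglik(prob, data):
    """Log-likelihood of 0/1 data under Bernoulli(prob); prob may be an array."""
    k = data.sum()
    n = data.size
    return k * np.log(prob) + (n - k) * np.log1p(-prob)


def normalize(log_post):
    """Turn an unnormalized log density on the grid into weights summing to 1."""
    w = np.exp(log_post - log_post.max())
    return w / w.sum()


rng = np.random.default_rng(0)
p_true, N = 0.05, 1500
occurrences = rng.random(N) < p_true

# Grid over (0, 1), avoiding the endpoints where log(0) would appear.
grid = np.linspace(0.001, 0.999, 999)
log_prior = np.zeros_like(grid)  # Uniform(0, 1) prior is flat on the grid

# Buggy model: the likelihood uses the constant p_true, so it ignores p.
log_post_buggy = log_prior + bernoulli_loglik(np.full_like(grid, p_true), occurrences)
# Fixed model: the likelihood uses each candidate value of p.
log_post_fixed = log_prior + bernoulli_loglik(grid, occurrences)

post_buggy = normalize(log_post_buggy)
post_fixed = normalize(log_post_fixed)

# Conjugate check: the exact posterior is Beta(1 + k, 1 + N - k).
k = occurrences.sum()
exact_mean = (1 + k) / (2 + N)

print("buggy posterior is flat:", np.allclose(post_buggy, post_buggy[0]))
print("fixed posterior mode:", grid[post_fixed.argmax()], "vs k/N:", k / N)
print("fixed grid mean:", (grid * post_fixed).sum(), "exact Beta mean:", exact_mean)
```

The flat buggy curve corresponds to the prior-only trace in the original plot. The samples of p from the corrected PyMC3 model should roughly follow the fixed curve and the Beta(1 + k, 1 + N - k) posterior.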