How are pymc3 variables assigned to the currently active model?


In PyMC3 you can do this:

import pymc3 as pm

# X1, X2 and Y are data arrays defined elsewhere
basic_model = pm.Model()

with basic_model:

    # Priors for unknown model parameters
    alpha = pm.Normal('alpha', mu=0, sd=10)
    beta = pm.Normal('beta', mu=0, sd=10, shape=2)
    sigma = pm.HalfNormal('sigma', sd=1)

    # Expected value of outcome
    mu = alpha + beta[0]*X1 + beta[1]*X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal('Y_obs', mu=mu, sd=sigma, observed=Y)

and all the variables (pm.Normal, ...) will be "assigned" to the basic_model instance.

From the docs:

The first line,

basic_model = Model()

creates a new Model object which is a container for the model random variables.

Following instantiation of the model, the subsequent specification of the model components is performed inside a with statement:

with basic_model:

This creates a context manager, with our basic_model as the context, that includes all statements until the indented block ends. This means all PyMC3 objects introduced in the indented code block below the with statement are added to the model behind the scenes. Absent this context manager idiom, we would be forced to manually associate each of the variables with basic_model right after we create them. If you try to create a new random variable without a with model: statement, it will raise an error since there is no obvious model for the variable to be added to.

I think it is very elegant for the purpose of the library. How is that actually implemented?

The only way I can think of is something in the spirit of this:

class Model:
    active_model = None
    def __enter__(self):
        Model.active_model = self
    def __exit__(self, *args, **kwargs):
        Model.active_model = None

class Normal:
    def __init__(self):
        if Model.active_model is None:
            raise ValueError("can't instantiate variable outside of Model")
        else:
            self.model = Model.active_model

It works in my simple REPL tests, but I am not sure whether it has pitfalls or whether it really is that simple.


Solution

  • You are very close, and for a while the actual implementation was quite similar to yours. The difference is that the active contexts are kept in a threading.local object, so each thread sees its own state, and that state is a list used as a stack, so models can be nested. The real implementation also sets theano configuration when entering a model context; I removed that part here:

    import threading

    class Context(object):
        contexts = threading.local()
    
        def __enter__(self):
            type(self).get_contexts().append(self)
            return self
    
        def __exit__(self, typ, value, traceback):
            type(self).get_contexts().pop()
    
        @classmethod
        def get_contexts(cls):
            if not hasattr(cls.contexts, 'stack'):
                cls.contexts.stack = []
            return cls.contexts.stack
    
        @classmethod
        def get_context(cls):
            """Return the deepest context on the stack."""
            try:
                return cls.get_contexts()[-1]
            except IndexError:
                raise TypeError("No context on context stack")
    

    The Model class subclasses Context, so when writing algorithms we can call Model.get_context() from inside a with block and get the currently active model. This is equivalent to your Model.active_model.
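
One pitfall of a single active_model attribute is nesting: leaving an inner with block resets it to None instead of restoring the outer model. The following is a minimal sketch of how the stack avoids that, not PyMC3's actual code. ToyModel and ToyVariable are hypothetical stand-ins for pm.Model and a distribution such as pm.Normal, and the Context class repeats the snippet above so that the example runs on its own.

```python
import threading


class Context:
    """Stack-based context manager, following the snippet above."""
    contexts = threading.local()

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        # Each thread lazily gets its own stack.
        if not hasattr(cls.contexts, "stack"):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")


class ToyModel(Context):
    """Hypothetical stand-in for pm.Model; only collects variable names."""
    def __init__(self, name):
        self.name = name
        self.variables = []


class ToyVariable:
    """Hypothetical stand-in for a distribution such as pm.Normal."""
    def __init__(self, name):
        self.name = name
        # Look up the innermost active model and register with it.
        self.model = ToyModel.get_context()
        self.model.variables.append(name)


outer, inner = ToyModel("outer"), ToyModel("inner")

with outer:
    ToyVariable("alpha")
    with inner:
        ToyVariable("beta")
    # Leaving the inner block pops only the inner model,
    # so the outer model is active again here.
    ToyVariable("sigma")

print(outer.variables)  # ['alpha', 'sigma']
print(inner.variables)  # ['beta']

# A new thread starts with an empty stack, even while `outer` is active here.
seen = []
with outer:
    worker = threading.Thread(
        target=lambda: seen.append(len(Context.get_contexts()))
    )
    worker.start()
    worker.join()
print(seen)  # [0]
```

With the single-attribute version from the question, the "sigma" line would raise the ValueError, because exiting the inner block sets active_model to None.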