python pymc3 pymc

PyMC3 Normal with variance per column


I am trying to define a pymc3.Normal variable with the following as mu:

import numpy as np
import pymc3 as pm

mx = np.array([[0.25 , 0.5  , 0.75 , 1.   ],    
               [0.25 , 0.333, 0.25 , 0.   ],
               [0.25 , 0.167, 0.   , 0.   ],
               [0.25 , 0.   , 0.   , 0.   ]])
# PyMC3 random variables must be created inside a model context
with pm.Model() as model:
    epsilon = pm.Gamma('epsilon', alpha=10, beta=10)
    p_ = pm.Normal('p_', mu=mx, shape=mx.shape, sd=epsilon)

The problem is that every random variable in p_ gets the same standard deviation (epsilon). I would like the first row to use epsilon1, the second row epsilon2, and so on.

How can I do that?


Solution

  • Give epsilon its own shape argument so that it broadcasts against mu: with shape (1, 4), each column gets its own standard deviation. To demonstrate, let's generate some fake data to pass as observed, using known values for epsilon that we can compare against the inferred ones.

    Example Model

    import numpy as np
    import pymc3 as pm
    import arviz as az
    
    # fixed means and Gamma hyperparameters
    mu = np.array([[0.25 , 0.5  , 0.75 , 1.   ],    
                   [0.25 , 0.333, 0.25 , 0.   ],
                   [0.25 , 0.167, 0.   , 0.   ],
                   [0.25 , 0.   , 0.   , 0.   ]])
    alpha, beta = 10, 10
    
    # fake data
    np.random.seed(2019)
    
    # row vector will use a different sd for each column
    sd = np.random.gamma(alpha, 1.0/beta, size=(1,4))
    
    # generate 100 fake observations of the (4,4) random variables
    Y = np.random.normal(loc=mu, scale=sd, size=(100,4,4))
    
    # true column sd's
    print(sd)
    # [[0.90055471 1.24522079 0.85846659 1.19588367]]
    
    # empirical sd's per column (std over the 100 observations, averaged over rows)
    print(np.mean(np.std(Y, 0), 0))
    # [0.92028042 1.24437592 0.83383181 1.22717313]
    
    # model
    with pm.Model() as model:
        # use a (1,4) matrix to pool variance by columns
        epsilon = pm.Gamma('epsilon', alpha=alpha, beta=beta, shape=(1, mu.shape[1]))
    
        p_ = pm.Normal('p_', mu=mu, sd=epsilon, shape=mu.shape, observed=Y)
    
        trace = pm.sample(random_seed=2019)
    

    This samples well, and the resulting summary (the original image is not reproduced here) shows HPD intervals that clearly bound the true values of the standard deviations.
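
The choice of shape matters because of broadcasting. As a minimal sketch in plain NumPy (the values in `sd_values` are hypothetical, chosen only to make the pattern visible), the following shows which standard deviation each entry of the (4, 4) variable would receive for a (1, 4) versus a (4, 1) array:

```python
import numpy as np

mu = np.array([[0.25, 0.5,   0.75, 1.0],
               [0.25, 0.333, 0.25, 0.0],
               [0.25, 0.167, 0.0,  0.0],
               [0.25, 0.0,   0.0,  0.0]])

# Hypothetical standard deviations, for illustration only
sd_values = np.array([1.0, 2.0, 3.0, 4.0])

sd_per_column = sd_values.reshape(1, mu.shape[1])  # analogous to shape=(1, 4)
sd_per_row = sd_values.reshape(mu.shape[0], 1)     # analogous to shape=(4, 1)

# Expand each array to the full (4, 4) shape to see what broadcasting does
by_column = np.broadcast_to(sd_per_column, mu.shape)
by_row = np.broadcast_to(sd_per_row, mu.shape)

print(by_column)  # every row is identical: one value per column
print(by_row)     # every column is identical: one value per row
```

By the same reasoning, giving epsilon the shape `(mu.shape[0], 1)` in the model above should yield one standard deviation per row, which is what the question literally asks for. That variant was not run here, so treat it as an assumption to verify.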