How to evaluate a pymc2 model with train/test data?

I'm building a simple model in pymc2 and I want to evaluate it on both the train data and the test data.

I tried this line of code:

print('Accuracy on train data = {}%'.format((y.value == Y_train).mean() * 100))

but I think y.value is always the same as Y_train, because y is an observed variable, so this doesn't solve my problem.

My current code is

import numpy as np
import pymc as pm  # PyMC 2.x

number_of_samples = 10000
X = np.random.randn(100, 2)
Y = np.tanh(X[:, 0] + X[:, 1])
Y = 1. / (1. + np.exp(-(Y + Y)))
Y_train = Y > 0.5

w11 = pm.Normal('w11', mu=0., tau=1.)
w12 = pm.Normal('w12', mu=0., tau=1.)
w21 = pm.Normal('w21', mu=0., tau=1.)
w22 = pm.Normal('w22', mu=0., tau=1.)
w31 = pm.Normal('w31', mu=0., tau=1.)
w32 = pm.Normal('w32', mu=0., tau=1.)

x1 = X[:, 0]
x2 = X[:, 1]

x3 = pm.Lambda('x3', lambda w1=w11, w2=w12: np.tanh(w1 * x1 + w2 * x2))
x4 = pm.Lambda('x4', lambda w1=w21, w2=w22: np.tanh(w1 * x1 + w2 * x2))

@pm.deterministic
def sigmoid(x=w31 * x3 + w32 * x4):
    return 1. / (1. + np.exp(-x))

y = pm.Bernoulli('y', sigmoid, observed=True, value=Y_train)

model = pm.Model([w11, w12, w21, w22, w31, w32, y])
inference = pm.MCMC(model)
inference.sample(number_of_samples)  # draw posterior samples before evaluating

print('Accuracy on train data = {}%'.format((y.value == Y_train).mean() * 100))

And this is the network that I want to build.

[image: my network]

I want to compute the accuracy of my trained model on the train data and on a separate test set, but it is not clear to me how to do that.


  • I think what you might want is a posterior predictive check, which you can implement by adding an additional stochastic to your model:

    y_pred = pm.Bernoulli('y_pred', sigmoid)
    model = pm.Model([w11, w12, w21, w22, w31, w32, y, y_pred])

    To get the in-sample predictions, you can use some threshold (like 0.5) to map the probabilistic prediction from the trace of y_pred to a deterministic prediction suitable for measuring accuracy:

    y_pred_samples = y_pred.trace()
    y_pred_threshold = (y_pred_samples.mean(axis=0) > .5)
    print('Accuracy on train data = {}%'.format((y_pred_threshold == Y_train).mean() * 100))

    Here is a Jupyter Notebook that puts this all together: link.
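For the test data, one possible approach is to push the new inputs through the same network once per posterior draw of the weights and then threshold the averaged probability. The sketch below uses only NumPy. The helper names `predict_proba` and `accuracy` are hypothetical, and the random `weights` are stand-ins for the traces you would get from the sampled model (for example `w11.trace()`). Note that it averages the predicted probabilities over draws rather than sampled 0/1 values of y_pred, which is a small variation on the snippet above.

```python
import numpy as np

def predict_proba(X, w11, w12, w21, w22, w31, w32):
    """Forward pass of the 2-2-1 network for every posterior draw.

    X has shape (n_points, 2); each weight is a 1-D array of draws with
    shape (n_draws,). Returns an array of shape (n_draws, n_points).
    """
    x1, x2 = X[:, 0], X[:, 1]
    # Broadcast draws (rows) against data points (columns).
    h1 = np.tanh(w11[:, None] * x1 + w12[:, None] * x2)
    h2 = np.tanh(w21[:, None] * x1 + w22[:, None] * x2)
    return 1. / (1. + np.exp(-(w31[:, None] * h1 + w32[:, None] * h2)))

def accuracy(X, Y, *weights, threshold=0.5):
    # Average the probability over draws, then threshold it.
    p_mean = predict_proba(X, *weights).mean(axis=0)
    return ((p_mean > threshold) == Y).mean() * 100

# Stand-in draws (assumption); with a sampled model use w11.trace(), etc.
rng = np.random.RandomState(0)
weights = [rng.randn(500) for _ in range(6)]

# Test set generated the same way as the training data in the question.
X_test = rng.randn(50, 2)
Y = np.tanh(X_test[:, 0] + X_test[:, 1])
Y = 1. / (1. + np.exp(-(Y + Y)))
Y_test = Y > 0.5

print('Accuracy on test data = {}%'.format(accuracy(X_test, Y_test, *weights)))
```

Because the stand-in weights are random noise, the printed accuracy is not meaningful here; it only becomes informative once the arrays come from the actual traces.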