There is a function pymc3.traceplot() that plots the traces of the sampling process. I see that it takes an argument, lines, which accepts a dictionary in which you can pass the means to be drawn as lines on the plots. How would you go about doing this?
You can pass any value you want, not only the mean.
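import pymc3 as pm
theta_val = 0.35  # any reference value, not necessarily the mean
pm.traceplot(trace, lines={'theta': theta_val})  # trace is the result of pm.sample()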
Here theta is the name of the variable in the model and theta_val is the value you want drawn as a line overlaid on its trace.
You can compute the mean from the trace by doing:
trace['theta'].mean()
or, to build the dictionary for every variable at once:
lines = {var: trace[var].mean() for var in trace.varnames}
pm.traceplot(trace, lines=lines)
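As a minimal sketch of what that dictionary comprehension produces, the snippet below assumes a plain dict of sample lists (the hypothetical fake_trace) as a stand-in for a real PyMC3 trace, so it runs without PyMC3 installed. It also shows overriding one entry with an arbitrary reference value, since the lines dictionary is not limited to means.

```python
from statistics import mean

# Hypothetical stand-in for a PyMC3 trace: variable name -> list of posterior draws.
# With a real trace you would use trace[var] and iterate over trace.varnames instead.
fake_trace = {
    "theta": [0.30, 0.34, 0.36, 0.40],
    "sigma": [1.0, 1.2, 0.8, 1.0],
}

# One reference value per variable: the mean of its draws.
lines = {var: mean(samples) for var, samples in fake_trace.items()}

# Any value can be used, not only the mean, e.g. a known "true" value for theta.
lines["theta"] = 0.35

print(lines)
```

With a real trace, the resulting dictionary would be passed as pm.traceplot(trace, lines=lines).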