
There is a function pymc3.traceplot() that plots the traces from the sampling process. I see that it takes an argument, lines, which accepts a dictionary through which you can pass the means to be drawn as lines on the plot. How would you go about doing this?

Nigel Ng

1 Answer


You can pass any value you want, not only the mean.

import pymc3 as pm
theta_val = 0.35  # any reference value, e.g. a known true value
pm.traceplot(trace, lines={'theta': theta_val})  # trace comes from pm.sample()


Here, theta is the name of the variable in the model, and theta_val is the value you want drawn as a line over (overlaid on) its trace.

You can compute the mean from the trace by doing:

trace['theta'].mean()
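To make clear what value ends up in the dictionary, here is a minimal stdlib-only sketch. It assumes a hypothetical stand-in for the trace: a plain dict mapping the variable name to a list of posterior draws (a real PyMC3 trace returns a NumPy array for trace['theta'], whose .mean() gives the same number for a scalar variable).

```python
import statistics

# Hypothetical stand-in for a PyMC3 trace: variable name -> posterior draws.
# A real MultiTrace returns a NumPy array for trace['theta'] instead of a list.
trace = {"theta": [0.30, 0.35, 0.40, 0.33, 0.37]}

# Equivalent of trace['theta'].mean() for a scalar variable
theta_mean = statistics.mean(trace["theta"])

# Dictionary in the form passed as pm.traceplot(trace, lines=...)
lines = {"theta": theta_mean}
print(lines)
```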

Or, to build the dictionary for every variable in the trace at once:

lines = {var: trace[var].mean() for var in trace.varnames}
pm.traceplot(trace, lines=lines)
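Because any value can be passed, the same comprehension works with other summary statistics. The sketch below is illustrative only: it uses a hypothetical dict-based stand-in trace, a hypothetical varnames list playing the role of trace.varnames, and a hypothetical helper make_lines that is not part of PyMC3.

```python
import statistics

# Hypothetical stand-in trace with two scalar variables (name -> draws)
trace = {
    "theta": [0.30, 0.35, 0.40, 0.33, 0.37],
    "sigma": [1.0, 1.2, 0.9, 1.1, 1.3],
}
varnames = list(trace)  # plays the role of trace.varnames

def make_lines(trace, varnames, stat=statistics.mean):
    """Hypothetical helper: build {variable name: reference value} for traceplot's lines."""
    return {var: stat(trace[var]) for var in varnames}

# Posterior means (as in the comprehension above) and, alternatively, medians
mean_lines = make_lines(trace, varnames)
median_lines = make_lines(trace, varnames, stat=statistics.median)
print(mean_lines)
print(median_lines)
```

Passing the statistic as a function keeps one comprehension for every choice of reference value; with a real trace, the resulting dict would go to pm.traceplot(trace, lines=...).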
aloctavodia