When I try to give initial start values in for the standard deviations of LKJCholeskyCov
I get a bad initial energy
error in pymc3
.
Below, the first code runs fine. But the second will give you that error. I believe I am doing something wrong in setting up the start values array.
Working example:
pip install arviz==0.11
pip install pymc3==3.11.1
# Working example:
import pymc3 as pm
n_samples = 20
n_tune_samples = 10
mu = np.zeros(3)
true_cov = np.array([[1.0, 0.5, 0.1],
[0.5, 2.0, 0.2],
[0.1, 0.2, 1.0]])
data = np.random.multivariate_normal(mu, true_cov, 10)
print(data.shape)
with pm.Model() as model1:
sd_dist = pm.Exponential.dist(1.0, shape=3)
print(sd_dist.shape)
chol, corr, stds = pm.LKJCholeskyCov('chol_cov', n=3, eta=2,
sd_dist=sd_dist, compute_corr=True)
vals = pm.MvNormal('vals', mu=mu, chol=chol, observed=data)
with model1:
trace1 = pm.sample(draws=n_samples, tune=n_tune_samples)
Not working example:
# Not working example:
import pymc3 as pm
n_samples = 20
n_tune_samples = 10
mu = np.zeros(3)
true_cov = np.array([[1.0, 0.5, 0.1],
[0.5, 2.0, 0.2],
[0.1, 0.2, 1.0]])
data = np.random.multivariate_normal(mu, true_cov, 10)
print(data.shape)
with pm.Model() as model2:
sd_dist = pm.Exponential.dist(1.0, shape=3)
print(sd_dist.shape)
chol, corr, stds = pm.LKJCholeskyCov('chol_cov', n=3, eta=2,
sd_dist=sd_dist, compute_corr=True)
vals = pm.MvNormal('vals', mu=mu, chol=chol, observed=data)
with model2:
chol_init = np.diag([0.56, 0.61, 0.74])[np.tril_indices(3)]
trace2 = pm.sample(draws=n_samples, tune=n_tune_samples,
start={'chol_cov':chol_init})
The error message:
Only 20 samples in chain.
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [chol_cov]
0.00% [0/120 00:00<00:00 Sampling 4 chains, 0 divergences]
Bad initial energy, check any log probabilities that are inf or -inf, nan or very small:
Series([], )
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "C:\Users\ilyas\Anaconda3\envs\pymc\lib\site-packages\pymc3\parallel_sampling.py", line 191, in _start_loop
point, stats = self._compute_point()
File "C:\Users\ilyas\Anaconda3\envs\pymc\lib\site-packages\pymc3\parallel_sampling.py", line 216, in _compute_point
point, stats = self._step_method.step(self._point)
File "C:\Users\ilyas\Anaconda3\envs\pymc\lib\site-packages\pymc3\step_methods\arraystep.py", line 276, in step
apoint, stats = self.astep(array)
File "C:\Users\ilyas\Anaconda3\envs\pymc\lib\site-packages\pymc3\step_methods\hmc\base_hmc.py", line 159, in astep
raise SamplingError("Bad initial energy")
pymc3.exceptions.SamplingError: Bad initial energy
"""
The above exception was the direct cause of the following exception:
SamplingError Traceback (most recent call last)
SamplingError: Bad initial energy
The above exception was the direct cause of the following exception:
ParallelSamplingError Traceback (most recent call last)
<ipython-input-412-6bc0303342ea> in <module>
21 chol_init = np.diag([0.56, 0.61, 0.74])[np.tril_indices(3)]
22
---> 23 trace2 = pm.sample(draws=n_samples, tune=n_tune_samples,
24 start={'chol_cov':chol_init})
~\Anaconda3\envs\pymc\lib\site-packages\pymc3\sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
556 _print_step_hierarchy(step)
557 try:
--> 558 trace = _mp_sample(**sample_args, **parallel_args)
559 except pickle.PickleError:
560 _log.warning("Could not pickle model, sampling singlethreaded.")
~\Anaconda3\envs\pymc\lib\site-packages\pymc3\sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, pickle_backend, **kwargs)
1474 try:
1475 with sampler:
-> 1476 for draw in sampler:
1477 trace = traces[draw.chain - chain]
1478 if trace.supports_sampler_stats and draw.stats is not None:
~\Anaconda3\envs\pymc\lib\site-packages\pymc3\parallel_sampling.py in __iter__(self)
477
478 while self._active:
--> 479 draw = ProcessAdapter.recv_draw(self._active)
480 proc, is_last, draw, tuning, stats, warns = draw
481 self._total_draws += 1
~\Anaconda3\envs\pymc\lib\site-packages\pymc3\parallel_sampling.py in recv_draw(processes, timeout)
357 else:
358 error = RuntimeError("Chain %s failed." % proc.chain)
--> 359 raise error from old_error
360 elif msg[0] == "writing_done":
361 proc._readable = True
ParallelSamplingError: Bad initial energy
Any help is appreciated.
ilyas