I have a model that seems to run smoothly. I initialize my NUTS sampler with a Metropolis sample, and the Metropolis sample completes without trouble. It then moves into the NUTS sampler, completes all iterations, but then hangs.
I have tried:
- Restarting the kernel
- Clearing my Theano cache
- Rebooting the computer and restarting the Docker container it runs in
I have also noticed:
- Metropolis sampling with multiple chains does complete and then terminate.
- Changing the seed or the model specification slightly sometimes makes the run terminate properly.
But since it's not erroring out, I'm not sure how to troubleshoot. When I interrupt the process, it's always stuck in the same place. The output is pasted below. Any help diagnosing the problem would be much appreciated.
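The only extra diagnostic I can think of, which I have not run yet and am only sketching here, is to have the main process dump its own stack after a timeout instead of waiting for a manual interrupt, using the standard-library faulthandler module:

import faulthandler, sys

# Sketch: if the process is still alive after 10 minutes, print the Python
# stack of every thread to stderr without killing anything.
faulthandler.dump_traceback_later(600, repeat=False, file=sys.stderr, exit=False)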
Code for my sample is here:
with my_model:
    start_trace = pm.sample(7000, step=pm.Metropolis())

start_sds = {}
nms = start_trace.varnames
for i in nms:
    start_sds[i] = start_trace[i].std()

with my_model:
    step = pm.NUTS(scaling=my_model.dict_to_array(start_sds)**2,
                   is_cov=True)
    signal_trace = pm.sample(500, step=step, start=start_trace[-1], njobs=3)
It finishes sampling. The first progress bar is for the Metropolis sample, the second for NUTS:
100%|██████████| 7000/7000 [00:09<00:00, 718.51it/s]
100%|██████████| 500/500 [01:37<00:00, 1.47s/it]
Looking at top, there are four processes, each using about the same amount of memory, but only one of them is using a CPU. Usually, when the run terminates properly, it ends about the time the other three processes stop using CPU. The fact that these other processes do stop suggests that sampling has finished and that the issue is in terminating the multiprocessing.
It sits until I interrupt it (I've left it overnight), and I get the following error:
Process ForkPoolWorker-1:
KeyboardInterrupt
Process ForkPoolWorker-3:
Traceback (most recent call last):
Traceback (most recent call last):
File "/usr/lib/python3.4/multiprocessing/process.py", line 254, in _bootstrap
self.run()
File "/usr/lib/python3.4/multiprocessing/process.py", line 254, in _bootstrap
self.run()
File "/usr/lib/python3.4/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.4/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.4/multiprocessing/pool.py", line 108, in worker
task = get()
File "/opt/ds/lib/python3.4/site-packages/joblib/pool.py", line 362, in get
return recv()
File "/usr/lib/python3.4/multiprocessing/pool.py", line 108, in worker
task = get()
File "/opt/ds/lib/python3.4/site-packages/joblib/pool.py", line 360, in get
racquire()
File "/usr/lib/python3.4/multiprocessing/connection.py", line 250, in recv
buf = self._recv_bytes()
KeyboardInterrupt
File "/usr/lib/python3.4/multiprocessing/connection.py", line 416, in _recv_bytes
buf = self._recv(4)
File "/usr/lib/python3.4/multiprocessing/connection.py", line 383, in _recv
chunk = read(handle, remaining)
Process ForkPoolWorker-2:
File "/usr/lib/python3.4/multiprocessing/process.py", line 254, in _bootstrap
self.run()
Traceback (most recent call last):
File "/usr/lib/python3.4/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/usr/lib/python3.4/multiprocessing/pool.py", line 108, in worker
task = get()
File "/opt/ds/lib/python3.4/site-packages/joblib/pool.py", line 360, in get
racquire()
KeyboardInterrupt
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
/opt/ds/lib/python3.4/site-packages/joblib/parallel.py in retrieve(self)
698 if getattr(self._backend, 'supports_timeout', False):
--> 699 self._output.extend(job.get(timeout=self.timeout))
700 else:
/usr/lib/python3.4/multiprocessing/pool.py in get(self, timeout)
592 def get(self, timeout=None):
--> 593 self.wait(timeout)
594 if not self.ready():
/usr/lib/python3.4/multiprocessing/pool.py in wait(self, timeout)
589 def wait(self, timeout=None):
--> 590 self._event.wait(timeout)
591
/usr/lib/python3.4/threading.py in wait(self, timeout)
552 if not signaled:
--> 553 signaled = self._cond.wait(timeout)
554 return signaled
/usr/lib/python3.4/threading.py in wait(self, timeout)
289 if timeout is None:
--> 290 waiter.acquire()
291 gotit = True
KeyboardInterrupt:
During handling of the above exception, another exception occurred:
KeyboardInterrupt Traceback (most recent call last)
<ipython-input-13-cad95f434ae5> in <module>()
87 step = pm.NUTS(scaling=my_model.dict_to_array(start_sds)**2,
88 is_cov=True)
---> 89 signal_trace = pm.sample(500,step=step,start=start_trace[-1],njobs=3)
90
91 pr = forestplot(signal_trace[-500:],
/opt/ds/lib/python3.4/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain, njobs, tune, progressbar, model, random_seed)
173 sample_func = _sample
174
--> 175 return sample_func(**sample_args)
176
177
/opt/ds/lib/python3.4/site-packages/pymc3/sampling.py in _mp_sample(**kwargs)
322 random_seed=rseed[i],
323 start=start_vals[i],
--> 324 **kwargs) for i in range(njobs))
325 return merge_traces(traces)
326
/opt/ds/lib/python3.4/site-packages/joblib/parallel.py in __call__(self, iterable)
787 # consumption.
788 self._iterating = False
--> 789 self.retrieve()
790 # Make sure that we get a last message telling us we are done
791 elapsed_time = time.time() - self._start_time
/opt/ds/lib/python3.4/site-packages/joblib/parallel.py in retrieve(self)
719 # scheduling.
720 ensure_ready = self._managed_backend
--> 721 backend.abort_everything(ensure_ready=ensure_ready)
722
723 if not isinstance(exception, TransportableException):
/opt/ds/lib/python3.4/site-packages/joblib/_parallel_backends.py in abort_everything(self, ensure_ready)
143 def abort_everything(self, ensure_ready=True):
144 """Shutdown the pool and restart a new one with the same parameters"""
--> 145 self.terminate()
146 if ensure_ready:
147 self.configure(n_jobs=self.parallel.n_jobs, parallel=self.parallel,
/opt/ds/lib/python3.4/site-packages/joblib/_parallel_backends.py in terminate(self)
321 def terminate(self):
322 """Shutdown the process or thread pool"""
--> 323 super(MultiprocessingBackend, self).terminate()
324 if self.JOBLIB_SPAWNED_PROCESS in os.environ:
325 del os.environ[self.JOBLIB_SPAWNED_PROCESS]
/opt/ds/lib/python3.4/site-packages/joblib/_parallel_backends.py in terminate(self)
134 if self._pool is not None:
135 self._pool.close()
--> 136 self._pool.terminate() # terminate does a join()
137 self._pool = None
138
/opt/ds/lib/python3.4/site-packages/joblib/pool.py in terminate(self)
604 for i in range(n_retries):
605 try:
--> 606 super(MemmapingPool, self).terminate()
607 break
608 except OSError as e:
/usr/lib/python3.4/multiprocessing/pool.py in terminate(self)
494 self._state = TERMINATE
495 self._worker_handler._state = TERMINATE
--> 496 self._terminate()
497
498 def join(self):
/usr/lib/python3.4/multiprocessing/util.py in __call__(self, wr, _finalizer_registry, sub_debug, getpid)
183 sub_debug('finalizer calling %s with args %s and kwargs %s',
184 self._callback, self._args, self._kwargs)
--> 185 res = self._callback(*self._args, **self._kwargs)
186 self._weakref = self._callback = self._args = \
187 self._kwargs = self._key = None
/usr/lib/python3.4/multiprocessing/pool.py in _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, worker_handler, task_handler, result_handler, cache)
524
525 util.debug('helping task handler/workers to finish')
--> 526 cls._help_stuff_finish(inqueue, task_handler, len(pool))
527
528 assert result_handler.is_alive() or len(cache) == 0
/usr/lib/python3.4/multiprocessing/pool.py in _help_stuff_finish(inqueue, task_handler, size)
509 # task_handler may be blocked trying to put items on inqueue
510 util.debug('removing tasks from inqueue until task handler finished')
--> 511 inqueue._rlock.acquire()
512 while task_handler.is_alive() and inqueue._reader.poll():
513 inqueue._reader.recv()
KeyboardInterrupt:
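The interrupt traceback shows the main process blocked in inqueue._rlock.acquire() inside joblib's pool teardown, which fits the theory that sampling itself has finished. A workaround I'm considering (just a sketch, not yet tested against this hang) is to run the three chains sequentially with njobs=1 and merge the traces by hand, bypassing the pool entirely:

from pymc3.backends.base import merge_traces

# Sketch: sample each chain in the main process (njobs=1) and merge the
# resulting traces manually, avoiding the multiprocessing teardown.
with my_model:
    chain_traces = [pm.sample(500, step=step, start=start_trace[-1],
                              chain=i, njobs=1)
                    for i in range(3)]
    signal_trace = merge_traces(chain_traces)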
Here's the model that most consistently fails:
with pm.Model() as twitter_signal:
    # location fixed effects
    mu_c = pm.Flat('mu_c')
    sig_c = pm.HalfCauchy('sig_c', beta=2.5)
    c_raw = pm.Normal('c_raw', mu=0, sd=1, shape=n_location)
    c = pm.Deterministic('c', mu_c + sig_c * c_raw)

    # time fixed effects
    mu_t = pm.Flat('mu_t')
    sig_t = pm.HalfCauchy('sig_t', beta=2.5)
    t_raw = pm.Normal('t_raw', mu=0, sd=1, shape=n_time)
    t = pm.Deterministic('t', mu_t + sig_t * t_raw)

    # signal effect
    b_sig = pm.Normal('b_sig', mu=0, sd=100**2, shape=1)

    # control effect
    b_control = pm.Normal('b_control', mu=0, sd=100**2, shape=1)

    # linear model (y_hat) with log link
    theta = c[df.location.values] + \
            t[df.dates.values] + \
            (b_sig[df.c.values] * df.sig.values) + \
            (b_control[df.c.values] * df.control.values)

    disp = pm.HalfCauchy('disp', beta=2.5)

    # likelihood
    y = pm.NegativeBinomial('y', mu=np.exp(theta),
                            alpha=disp,
                            observed=df.loc[:, yvar])
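For reference, df above is a pandas DataFrame. Something shaped like the following made-up data (column names as used in the model, values and the yvar name invented purely for illustration) is enough to instantiate it:

import numpy as np
import pandas as pd

# Fake data with the same structure as my real data (values are invented).
n_location, n_time, n_obs = 5, 10, 200
yvar = 'tweet_count'  # placeholder name for the count outcome column
rng = np.random.RandomState(0)
df = pd.DataFrame({
    'location': rng.randint(0, n_location, n_obs),  # integer location codes
    'dates': rng.randint(0, n_time, n_obs),         # integer time codes
    'c': np.zeros(n_obs, dtype=int),                # index into the shape-1 coefficients
    'sig': rng.normal(size=n_obs),                  # signal covariate
    'control': rng.normal(size=n_obs),              # control covariate
    yvar: rng.poisson(5, size=n_obs),               # count outcome
})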