MWE is below.
My code spawns processes with torch.multiprocessing.Pool and manages communication with the parent through a JoinableQueue. I followed an online guide to handle CTRL+C gracefully, and everything works fine. In some cases, though (my real code does more than the MWE), errors occur in the function run by the children (online_test()). When that happens, the code hangs forever, because the children never notify the parent that something went wrong. I tried adding try ... except ... finally in the main loop of the children, with queue.task_done() in the finally clause, but nothing changed.
I need the parent to be notified of any error in the children and to terminate everything gracefully. How can I do that? Thanks!
EDIT
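The suggested solution (an error_callback that terminates the pool) does not work. The handler catches the exception, but the main code is left hanging in test_queue.join(), which waits until every queued item has been marked done with task_done().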
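To illustrate why the parent stays blocked, here is a minimal thread-based sketch (not the pool setup of the MWE). It assumes queue.Queue behaves like the manager's JoinableQueue with respect to task_done()/join(), and it reads unfinished_tasks, an attribute of CPython's queue.Queue used here only for inspection. The names worker, run_quietly and unfinished_after_crash are made up for this illustration.

```python
import queue
import threading


def worker(q):
    """Consume items; fail on 'b' before task_done() is reached."""
    while True:
        item = q.get()
        if item == 'b':
            raise NotImplementedError(item)
        q.task_done()


def run_quietly(q):
    """Swallow the exception, roughly as a pool does before calling error_callback."""
    try:
        worker(q)
    except NotImplementedError:
        pass  # the worker is gone, but the queue's counter is untouched


def unfinished_after_crash():
    q = queue.Queue()
    for item in ['a', 'b', 'c']:
        q.put(item)
    t = threading.Thread(target=run_quietly, args=(q,), daemon=True)
    t.start()
    t.join(timeout=5)  # the worker thread ends as soon as 'b' fails
    # 'b' was fetched but never marked done, and 'c' was never fetched,
    # so a q.join() at this point would block forever.
    return q.unfinished_tasks
```

Calling unfinished_after_crash() returns the number of items that join() would still be waiting for. Terminating the workers does not change that counter.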
import signal
import numpy as np
import multiprocessing as mp

STOP = 'STOP'


def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)


def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()


def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()


if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
                     args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))
        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))
    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()
    pool.join()
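
One possible direction, sketched under assumptions rather than offered as a verified fix for the MWE: the worker acknowledges every item in a finally clause and reports failures through a second queue, so join() always returns and the parent can check for errors afterwards. The sketch uses threads and queue.Queue to stay self-contained. The errors queue and the run_epochs and demo helpers are hypothetical names, and whether the same code carries over unchanged to queues created by mp.Manager() is an assumption that would need to be checked.

```python
import queue
import threading
import traceback

STOP = 'STOP'


def online_test(tasks, errors):
    """Worker loop: always acknowledge an item, report failures via `errors`."""
    while True:
        epoch, data_id = tasks.get()
        try:
            if data_id == STOP:
                break
            if epoch == 1:
                raise NotImplementedError(f'test failed for {data_id}')
        except Exception:
            # Hand the traceback to the parent instead of dying silently.
            errors.put(traceback.format_exc())
        finally:
            # Runs for normal items, failed items and STOP alike.
            tasks.task_done()


def run_epochs(tasks, errors, n_epochs=10):
    """Parent loop: return the first reported error, or None if all epochs pass."""
    for epoch in range(n_epochs):
        for data_id in ['a', 'b', 'c']:
            tasks.put((epoch, data_id))
        tasks.join()  # returns because every item has been acknowledged
        if not errors.empty():
            return errors.get()
    return None


def demo():
    tasks, errors = queue.Queue(), queue.Queue()
    t = threading.Thread(target=online_test, args=(tasks, errors), daemon=True)
    t.start()
    error = run_epochs(tasks, errors)
    tasks.put((-1, STOP))  # let the worker leave its loop
    t.join(timeout=5)
    return error
```

Here demo() returns the traceback text of the first failure in epoch 1 instead of hanging. In the pool-based MWE, the parent would presumably call pool.terminate() at the point where run_epochs returns an error.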