MWE is below.
My code spawns processes with torch.multiprocessing.Pool, and I manage communication with the parent through a JoinableQueue. I followed an online guide to handle CTRL+C gracefully, and everything works fine. In some cases, though (my real code does more than the MWE), the function run by the children (online_test()) raises an error. When that happens, the code hangs forever, because the children never notify the parent that something went wrong. I tried wrapping the main child loop in try ... except ... finally, with queue.task_done() in the finally clause, but nothing changed.

I need the parent to be notified about any child error and then terminate everything gracefully. How can I do that? Thanks!

EDIT
The suggested solution does not work. The handler catches the exception, but the main code is left hanging because test_queue.join() waits until every queued item has been marked done with task_done().

import signal
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

def error_handler(exception):
    print(f'{exception} occurred, terminating pool.')
    pool.terminate()

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError
        queue.task_done()


if __name__ == '__main__':
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test,
        args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()
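To see why the MWE hangs, here is a minimal single-process sketch of what happens to the queue's bookkeeping when the worker raises before calling task_done(). It uses a plain queue.Queue, which (in CPython) is the object a Manager's JoinableQueue proxies. online_test_once is a hypothetical one-item stand-in for online_test(), and unfinished_tasks is a CPython implementation attribute that I read here only for illustration.

```python
import queue


def online_test_once(q):
    # Hypothetical one-item stand-in for the MWE's online_test() loop.
    epoch, data_id = q.get()
    if epoch == 1:
        raise NotImplementedError('epoch == 1')  # crashes before task_done()
    q.task_done()


q = queue.Queue()  # what a Manager's JoinableQueue proxies in CPython
for data_id in ['a', 'b', 'c']:
    q.put((1, data_id))

try:
    online_test_once(q)
except NotImplementedError as exc:
    print(f'worker crashed: {exc!r}')

# Three put() calls, zero task_done() calls: q.join() would now block forever.
print(q.unfinished_tasks, q.qsize())  # 3 2
```

Nothing ever brings the unfinished-task count back to zero, so join() in the parent never returns, even though the error callback runs.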
Simon
  • To get the `error_callback` to be invoked, an unhandled exception has to occur in the process (i.e. not suppressed via some `try`/`except`). – martineau Apr 16 '22 at 21:27
  • @martineau Thanks for your help. My `online_test()` has no `try except` and raises `AttributeError` (in the line corresponding to `pred = model.value` in my MWE). The main process is left hanging. The handler is not called, because I don't see any `print(f'{exception} occurred, terminating pool.')` (as in the top answer of the suggested question). – Simon Apr 16 '22 at 23:19
  • Your "MWE" doesn't seem that minimal to me because it's still doing lots of things that appear to be unrelated to the problem. Could you post literally the shortest code possible that will reproduce the problem? Similar to what I did in my answer to the linked question. – martineau Apr 16 '22 at 23:29
  • @martineau You are right, my MWE was not really good. I updated it. In the new MWE the handler catches the exception but the main code is left hanging at epoch 1 because it waits for the queue to be empty. I tried to define the handler within the main function and to have it call `test_queue.task_done()` for the queue size, but then the code hangs at epoch 2. – Simon Apr 16 '22 at 23:55
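To illustrate martineau's first comment: error_callback only fires for an exception that escapes the task. The following is a minimal sketch that uses multiprocessing.pool.ThreadPool purely so it runs anywhere without pickling or start-method concerns. ThreadPool shares Pool's apply_async/error_callback machinery. The task functions swallowed and unhandled are hypothetical.

```python
from multiprocessing.pool import ThreadPool

errors = []  # filled by the error callback


def swallowed():
    try:
        raise ValueError('handled inside the task')
    except ValueError:
        return 'recovered'  # suppressed, so error_callback is never called


def unhandled():
    raise ValueError('escapes the task')  # reaches the pool, so error_callback fires


with ThreadPool(2) as pool:
    r1 = pool.apply_async(swallowed, error_callback=errors.append)
    r2 = pool.apply_async(unhandled, error_callback=errors.append)
    r1.wait()
    r2.wait()

print(r1.get(), [str(e) for e in errors])  # recovered ['escapes the task']
```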

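I got the code in your EDIT to work by doing these two things in the error handler function:

  1. Released the blocked test_queue.join() by calling test_queue.task_done() repeatedly, stopping once it raises ValueError. (This doesn't remove anything from test_queue; it only marks the pending items as done.)
  2. Set a global flag variable named aborted to True to indicate processing should stop.

Then, in the __main__ process, I added code that checks the aborted flag at the start of each epoch, before waiting for the previous epoch's testing to finish and starting another.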
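The first step works because join() only tracks a counter. It returns once task_done() has been called once per put(), whether or not the items were actually consumed. Below is a minimal sketch of that mechanism with a plain queue.Queue (the object a Manager's JoinableQueue proxies in CPython). release_join is a hypothetical helper name, and a helper thread stands in for the main process blocked in join().

```python
import queue
import threading


def release_join(q):
    """Hypothetical helper: mark every unfinished task on q as done.

    Items stay in q; only the unfinished-task counter is decremented,
    until task_done() raises ValueError because it is already zero.
    """
    released = 0
    while True:
        try:
            q.task_done()
        except ValueError:  # counter already at zero
            return released
        released += 1


q = queue.Queue()
for item in ['a', 'b', 'c']:
    q.put(item)
q.get()  # a consumer took 'a' and then crashed before calling task_done()

joiner = threading.Thread(target=q.join)  # stands in for the blocked main process
joiner.start()
n = release_join(q)
joiner.join(timeout=5)
print(n, joiner.is_alive(), q.qsize())  # 3 False 2
```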

Using a global seems a little hacky, but it works because the error handler function is executed in the main process (in the pool's result-handler thread), so it has access to the main module's globals. I remember when that detail dawned on me while I was working on the linked answer, and as you can see, it can prove to be important and useful.
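If the global feels too hacky, a threading.Event can carry the same signal without a global statement, because the callback and the main loop live in the same process. Here is a minimal sketch assuming you are free to swap the aborted flag for an Event. It uses ThreadPool only so it runs without pickling concerns, and result.wait() stands in for test_queue.join(). With multiprocessing.Pool the callback likewise runs in a thread of the parent process, so the same Event works there.

```python
import threading
from multiprocessing.pool import ThreadPool

aborted = threading.Event()  # set by the error callback, checked by the main loop
callback_threads = []


def online_test(epoch):
    if epoch == 1:
        raise NotImplementedError('epoch == 1')  # fake error, as in the question
    return epoch


def error_handler(exception):
    # Runs in the pool's result-handler thread, not in the main thread.
    callback_threads.append(threading.current_thread())
    aborted.set()


with ThreadPool(2) as pool:
    for epoch in range(10):
        if aborted.is_set():
            print('ABORTED by error_handler!')
            break
        result = pool.apply_async(online_test, (epoch,), error_callback=error_handler)
        result.wait()  # stand-in for test_queue.join()
        print(f'... epoch {epoch} testing is done')
```

Run as-is, it prints the "testing is done" line for epochs 0 and 1 and then "ABORTED by error_handler!", mirroring the sample run below.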

import signal
import multiprocessing as mp

STOP = 'STOP'

def initializer():
    """Ignore CTRL+C in the worker process."""
    signal.signal(signal.SIGINT, signal.SIG_IGN)

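def error_handler(exception):
    # Runs in a result-handler thread of the main process, so it shares
    # the main module's globals (pool, test_queue, aborted).
    global aborted
    print(f'{exception=} occurred, terminating pool.')
    aborted = True  # Set the flag first so the main loop sees it once join() returns.
    pool.terminate()
    print('pool terminated.')
    # Mark every outstanding item as done so the blocked test_queue.join() returns.
    # task_done() raises ValueError once the unfinished-task count reaches zero.
    while True:
        try:
            test_queue.task_done()
        except ValueError:
            break
    print('test_queue cleaned.')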

def online_test(queue):
    while True:
        epoch, data_id = queue.get()
        if data_id == STOP:
            print(f'... test function is stopping ...')
            break
        print(f'testing function for {data_id} has started for epoch {epoch}')
        if epoch == 1:
            raise NotImplementedError('epoch == 1')  # Fake error for testing.
        queue.task_done()


if __name__ == '__main__':
    aborted = False
    mp.set_start_method('spawn')
    manager = mp.Manager()
    test_queue = manager.JoinableQueue()
    pool = mp.Pool(initializer=initializer)
    pool.apply_async(online_test, args=(test_queue,), error_callback=error_handler)
    for i in ['a', 'b', 'c']:
        test_queue.put((0, i))

    try:
        for epoch in range(10):
            if aborted:  # Error occurred?
                print('ABORTED by error_handler!')
                break
            print('training epoch', epoch)
            print('... waiting for testing before moving on to next epoch ...')
            test_queue.join()
            print(f'... epoch {epoch} testing is done')
            for i in ['a', 'b', 'c']:
                test_queue.put((epoch + 1, i))

        for i in ['a', 'b', 'c']:
            test_queue.put((-1, STOP))

    except KeyboardInterrupt:
        pool.terminate()
    else:
        pool.close()

    pool.join()

Output from sample run:

training epoch 0
... waiting for testing before moving on to next epoch ...
testing function for a has started for epoch 0
testing function for b has started for epoch 0
testing function for c has started for epoch 0
... epoch 0 testing is done
testing function for a has started for epoch 1
training epoch 1
... waiting for testing before moving on to next epoch ...
exception=NotImplementedError('epoch == 1') occurred, terminating pool.
pool terminated.
... epoch 1 testing is done
test_queue cleaned.
ABORTED by error_handler!
Press any key to continue . . .
martineau