
I'm having a lot of success using Dask and Distributed to develop data analysis pipelines. One thing I'm still looking to improve, however, is how I handle exceptions.

Right now, if I write the following

import dask.bag

def my_function(value):
    return 1 / value

results = (dask.bag
    .from_sequence(range(-10, 10))
    .map(my_function))

print(results.compute())

... then on running the program I get a long, long list of tracebacks (one per worker, I'm guessing), the most relevant segment being:

distributed.utils - ERROR - division by zero
Traceback (most recent call last):
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/distributed/utils.py", line 193, in f
    result[0] = yield gen.maybe_future(func(*args, **kwargs))
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1015, in run
    value = future.result()
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/concurrent.py", line 237, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 3, in raise_exc_info
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1021, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/distributed/client.py", line 1473, in _get
    result = yield self._gather(packed)
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1015, in run
    value = future.result()
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/concurrent.py", line 237, in result
    raise_exc_info(self._exc_info)
  File "<string>", line 3, in raise_exc_info
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/tornado/gen.py", line 1021, in run
    yielded = self.gen.throw(*exc_info)
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/distributed/client.py", line 923, in _gather
    st.traceback)
  File "/Users/ajmazurie/test/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/six.py", line 685, in reraise
    raise value.with_traceback(tb)
  File "/mnt/lustrefs/work/aurelien.mazurie/test_dask/.env/pyenv-3.6.0-default/lib/python3.6/site-packages/dask/bag/core.py", line 1411, in reify
  File "test.py", line 9, in my_function
    return 1 / value
ZeroDivisionError: division by zero

Here, of course, a visual inspection tells me that the error was a division by zero. What I'm wondering is whether there is a better way to track these errors. For example, I cannot seem to catch the exception itself:

import dask.bag
import distributed

try:
    dask_scheduler = "127.0.0.1:8786"
    dask_client = distributed.Client(dask_scheduler)

    def my_function(value):
        return 1 / value

    results = (dask.bag
        .from_sequence(range(-10, 10))
        .map(my_function))

    #dask_client.persist(results)

    print(results.compute())

except Exception as e:
    print("error: %s" % e)

EDIT: Note that in my example I'm using distributed, not just dask. There is a dask-scheduler listening on port 8786 with four dask-worker processes registered to it.

This code will produce the exact same output as above, meaning that I'm not actually catching the exception with my try/except block.

Now, since we're talking about distributed tasks across a cluster, it is obviously non-trivial to propagate exceptions back to me. Is there any guideline for doing so? Right now my solution is to have functions return both a result and an optional error message, then process the results and error messages separately:

def my_function(value):
    try:
        return {"result": 1 / value, "error": None}
    except ZeroDivisionError:
        return {"result": None, "error": "boom!"}

results = (dask.bag
    .from_sequence(range(-10, 10))
    .map(my_function))

results = dask_client.persist(results)

errors = (results
    .pluck("error")
    .filter(lambda x: x is not None)
    .compute())

print(errors)

results = (results
    .pluck("result")
    .filter(lambda x: x is not None)
    .compute())

print(results)

This works, but I'm wondering if I'm sandblasting the soup cracker here. EDIT: Another option would be to use something like a Maybe monad, but once again I'd like to know if I'm overthinking it.
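For illustration, here is a minimal sketch of what that Maybe-like approach could look like (the `Maybe` container and `wrap_maybe` helper are hypothetical names, not part of dask or distributed):

from collections import namedtuple

# Hypothetical container: holds either a computed value or the
# exception raised while computing it.
Maybe = namedtuple("Maybe", ["value", "error"])

def wrap_maybe(func):
    def wrapper(*args, **kwargs):
        try:
            return Maybe(func(*args, **kwargs), None)
        except Exception as e:
            return Maybe(None, e)
    return wrapper

results = (dask.bag
    .from_sequence(range(-10, 10))
    .map(wrap_maybe(my_function)))

errors = (results
    .filter(lambda m: m.error is not None)
    .map(lambda m: m.error)
    .compute())

values = (results
    .filter(lambda m: m.error is None)
    .map(lambda m: m.value)
    .compute())

One thing to keep in mind with this pattern: the exception objects have to make the round trip back through the scheduler, so they need to be serializable (most built-in exceptions are).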

ajmazurie

2 Answers


Dask automatically packages up exceptions that occurred remotely and reraises them locally. Here is what I get when I run your example:

In [1]: from dask.distributed import Client

In [2]: client = Client('localhost:8786')

In [3]: import dask.bag

In [4]: try:
   ...:     def my_function (value):
   ...:         return 1 / value
   ...: 
   ...:     results = (dask.bag
   ...:         .from_sequence(range(-10, 10))
   ...:         .map(my_function))
   ...: 
   ...:     print(results.compute())
   ...: 
   ...: except Exception as e:
   ...:     import pdb; pdb.set_trace()
   ...:     print("error: %s" % e)
   ...:     
distributed.utils - ERROR - division by zero
> <ipython-input-4-17aa5fbfb732>(13)<module>()
-> print("error: %s" % e)
(Pdb) pp e
ZeroDivisionError('division by zero',)
MRocklin
  • That is intriguing; this is not the behavior I get running the exact same code, with the (key?) difference being that I'm using `distributed` and use a `Client` object to submit my jobs. Could it be that reraising exceptions is handled in **Dask** only, but not in **distributed**? As additional information, the workers are on a remote compute cluster. – ajmazurie Mar 01 '17 at 16:30
  • I've updated the answer to include connecting to a distributed scheduler. The result is the same. – MRocklin Mar 01 '17 at 19:52

You could wrap your function like so:

import sys

def exception_handler(orig_func):
    def wrapper(*args, **kwargs):
        try:
            return orig_func(*args, **kwargs)
        except Exception:
            # Exit the process with a non-zero status instead of
            # letting the exception propagate.
            sys.exit(1)
    return wrapper

You could use a decorator or do:

wrapped = exception_handler(my_function)
dask_client.map(wrapped, range(100))
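
Equivalently, the wrapper can be applied with decorator syntax at definition time (a sketch reusing the `exception_handler` defined above):

# Applying the wrapper at definition time via decorator syntax.
@exception_handler
def my_function(value):
    return 1 / value

# Client.map returns one future per input element.
futures = dask_client.map(my_function, range(100))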

This seems to automatically rebalance tasks if a worker fails. But I don't know how to remove the failed worker from the pool.

billiam