How can one construct a custom dask graph using a function that requires keyword arguments that are the result of another dask task?

The dask documentation and several Stack Overflow questions suggest using partial, toolz, or dask.compatibility.apply. All of these solutions work for static keyword arguments. My understanding, from "Including keyword arguments (kwargs) in custom Dask graphs" and from reading the source code and stepping through it in a debugger, is that dask.compatibility.apply might also work with keyword arguments that are the result of a dask computation. However, I can't get the syntax right, nor can I find the answer elsewhere.

The example below shows a relatively simple application of dask.compatibility.apply with a dask computed keyword value. Dask successfully passes the values of the computed args 'a' and 'b', as well as the static keyword value 'other'. However, it passes the string 'c' to the function rather than replacing it with its computed value.

import dask
from dask.compatibility import apply


def custom_func(a, b, other=None, c=None):
    print(a, b, other, c)
    return a * b / c / other


dsk = {
    'a': (sum, (1, 1)),
    'b': (sum, (2, 2)),
    'c': (sum, (3, 3)),
    'd': (apply, custom_func, ['a', 'b'], {'c': 'c', 'other': 2})
}

dask.visualize(dsk, filename='graph.png')
for key in sorted(dsk):
    print(key)
    print(dask.get(dsk, key))
    print('\n')

The output is below:

a
2


b
4


c
6


d
2 4 2 c
Traceback (most recent call last):
  File "dask_kwarg.py", line 20, in <module>
    print(dask.get(dsk, key))
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 562, in get_sync
    return get_async(apply_sync, 1, dsk, keys, **kwargs)
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 529, in get_async
    fire_task()
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 504, in fire_task
    callback=queue.put)
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 551, in apply_sync
    res = func(*args, **kwds)
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 295, in execute_task
    result = pack_exception(e, dumps)
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 290, in execute_task
    result = _execute_task(task, data)
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/local.py", line 271, in _execute_task
    return func(*args2)
  File "/Users/holmgren/miniconda3/envs/pvlib36/lib/python3.6/site-packages/dask/compatibility.py", line 50, in apply
    return func(*args, **kwargs)
  File "dask_kwarg.py", line 7, in custom_func
    return a * b / c / other
TypeError: unsupported operand type(s) for /: 'int' and 'str'

Will Holmgren
  • This seems specific enough that it could be considered a bug - either this is supposed to work, the documentation is unclear, or it's a feature that should be requested. I suggest you re-post to github. – mdurant Jul 09 '18 at 20:00
  • Thanks for the guidance. https://github.com/dask/dask/issues/3741 – Will Holmgren Jul 09 '18 at 21:13

1 Answer

One way is to find out how dask.delayed does it :)

In [1]: import dask

In [2]: @dask.delayed
   ...: def f(*args, **kwargs):
   ...:     pass
   ...: 

In [3]: dict(f(x=1).dask)
Out[3]: 
{'f-d2cd50e7-25b1-49c5-b463-f05198b09dfb': (<function dask.compatibility.apply>,
  <function __main__.f>,
  [],
  (dict, [['x', 1]]))}
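
The graph above suggests writing the keyword arguments as a (dict, [[key, value], ...]) task rather than as a plain dict. The sketch below is a toy model of why that could help, not dask's actual scheduler code: it assumes an evaluator that traverses lists and task tuples but returns dicts untouched. The names toy_get and the local apply are hypothetical stand-ins, and the graph key 'c-sum' is renamed from 'c' on purpose, because in this toy model any string that matches a graph key is substituted, including one sitting in the keyword-name position.

```python
def apply(func, args, kwargs=None):
    """Hypothetical stand-in for dask.compatibility.apply."""
    return func(*args, **(kwargs or {}))


def toy_get(dsk, arg):
    """Evaluate `arg` against graph `dsk` (toy model, not dask's code)."""
    if isinstance(arg, list):
        # Lists are traversed element by element.
        return [toy_get(dsk, item) for item in arg]
    if isinstance(arg, tuple) and arg and callable(arg[0]):
        # A tuple whose first element is callable is treated as a task.
        func, rest = arg[0], arg[1:]
        return func(*[toy_get(dsk, item) for item in rest])
    if isinstance(arg, str) and arg in dsk:
        # A string naming a graph key is replaced by that key's result.
        return toy_get(dsk, dsk[arg])
    # Everything else, including dicts, is returned unchanged.
    return arg


def custom_func(a, b, other=None, c=None):
    return a * b / c / other


dsk = {
    'a': (sum, (1, 1)),
    'b': (sum, (2, 2)),
    'c-sum': (sum, (3, 3)),
    # Keyword arguments expressed as a dict-building task, so the
    # evaluator can see and substitute the 'c-sum' key.
    'd': (apply, custom_func, ['a', 'b'],
          (dict, [['c', 'c-sum'], ['other', 2]])),
}

result = toy_get(dsk, 'd')

# A plain dict is opaque to this evaluator: 'c-sum' stays a string.
untouched = toy_get(dsk, {'c': 'c-sum', 'other': 2})
```

Whether a real dask scheduler behaves the same way depends on the dask version, so treat this as an illustration of the idea and check the result against the scheduler you actually use.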

Interestingly, this is also a case where the local scheduler and the distributed scheduler disagree. The distributed scheduler handles the original graph, with its plain dict of keyword arguments, without error.

In [1]: from dask.distributed import Client

In [2]: client = Client()

In [3]: import dask
   ...: from dask.compatibility import apply
   ...: 
   ...: 
   ...: def custom_func(a, b, other=None, c=None):
   ...:     print(a, b, other, c)
   ...:     return a * b / c / other
   ...: 
   ...: 
   ...: dsk = {
   ...:     'a': (sum, (1, 1)),
   ...:     'b': (sum, (2, 2)),
   ...:     'c': (sum, (3, 3)),
   ...:     'd': (apply, custom_func, ['a', 'b'], {'c': 'c', 'other': 2})
   ...: }
   ...: 

In [4]: for key in sorted(dsk):
   ...:     print(key, client.get(dsk, key))
   ...:     
a 2
b 4
c 6
2 4 2 6
d 0.6666666666666666
MRocklin