I am trying to restructure my code to use Dask instead of NumPy for large array computations. However, I am struggling with the runtime performance of Dask:
In[15]: import numpy as np
In[16]: import dask.array as da
In[17]: np_arr = np.random.rand(10, 10000, 10000)
In[18]: da_arr = da.from_array(np_arr, chunks=(-1, 'auto', 'auto'))
In[19]: %timeit np.mean(np_arr, axis=0)
1 loop, best of 3: 2.59 s per loop
In[20]: %timeit da_arr.mean(axis=0).compute()
1 loop, best of 3: 4.23 s per loop
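For reference, a scaled-down, self-contained variant of the benchmark above (sizes reduced so it runs quickly). My understanding is that generating the data with da.random, instead of wrapping an existing NumPy array with da.from_array, lets each chunk be materialized inside its own task rather than copied out of one big in-memory buffer:

```python
import dask.array as da

# Sizes scaled down from the question's (10, 10000, 10000) for a quick run.
# da.random creates each chunk lazily inside the task graph, so there is no
# up-front copy out of a single large NumPy buffer as with da.from_array.
arr = da.random.random((10, 1000, 1000), chunks=(-1, 250, 250))
result = arr.mean(axis=0).compute()
print(result.shape)  # (1000, 1000)
```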
I had a look at similar questions ("why is dot product in dask slower than in numpy"), but playing around with the chunk size did not help. I will mainly work with arrays of roughly the size shown above. Is it advisable to stick with NumPy for arrays of this size, or is there something I can tune? I have also tried the Client from dask.distributed, starting it with 16 processes and 4 threads per process (16-core CPU), but this made it even worse.
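A scaled-down sketch of the same computation with the scheduler pinned explicitly to threads. My understanding is that NumPy reductions like mean release the GIL, so the default threaded scheduler can parallelize them without the serialization overhead that a process-based Client adds:

```python
import numpy as np
import dask.array as da

# Scaled-down sizes; the question uses (10, 10000, 10000).
np_arr = np.random.rand(10, 1000, 1000)
da_arr = da.from_array(np_arr, chunks=(-1, 'auto', 'auto'))

# np.mean releases the GIL, so the threaded scheduler avoids the
# inter-process data copies that a process-based setup incurs.
result = da_arr.mean(axis=0).compute(scheduler='threads')
```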
Thanks in advance!
EDIT:
I have played around a bit more with Dask and distributed processing. Data transfer (shipping the array to the workers and retrieving the result) seems to be the major limitation, whereas the computation itself is really fast (436 ms compared to 9.51 s). But even for client.compute(), the wall time (12.1 s) is larger than for the plain do_stuff(data) call (9.5 s). Can this, and the data transfer in general, somehow be improved?
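One thing I considered for the transfer cost: a sketch (using a small in-process cluster for illustration; the session below connects to a remote scheduler instead) where the array is created directly on the workers rather than built locally and scattered, so only the small result ever travels over the wire:

```python
import dask.array as da
from dask.distributed import Client

# Small in-process cluster just for illustration; sizes are scaled down
# from the question's (400, 100, 10000).
client = Client(processes=False, n_workers=2, threads_per_worker=2)

# Build the random array on the cluster instead of creating it locally
# and scattering it: this removes the client -> worker transfer entirely.
x = da.random.random((40, 100, 1000), chunks=(10, 100, 1000))
x = x.persist()            # data now lives in worker memory
total = x.sum().compute()  # only the small scalar result travels back

client.close()
```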
In[3]: import numpy as np
In[4]: from dask.distributed import Client, wait
In[5]: from dask import delayed
In[6]: import dask.array as da
In[7]: client = Client('address:port')
In[8]: client
Out[8]: <Client: scheduler='tcp://address:port' processes=4 cores=16>
In[9]: data = np.random.rand(400, 100, 10000)
In[10]: %time [future] = client.scatter([data])
CPU times: user 8.36 s, sys: 5.08 s, total: 13.4 s
Wall time: 24.5 s
In[11]: x = da.from_delayed(delayed(future), shape=data.shape, dtype=data.dtype)
In[12]: x = x.rechunk(chunks=('auto', 'auto', 'auto'))
In[13]: x = client.persist(x)
In[14]: {w: len(keys) for w, keys in client.has_what().items()}
Out[14]:
{'tcp://address:port': 65,
'tcp://address:port': 0,
'tcp://address:port': 0,
'tcp://address:port': 0}
In[15]: client.rebalance(x)
In[16]: {w: len(keys) for w, keys in client.has_what().items()}
Out[16]:
{'tcp://address:port': 17,
'tcp://address:port': 16,
'tcp://address:port': 16,
'tcp://address:port': 16}
In[17]: def do_stuff(arr):
... arr = arr/3. + arr**2 - arr**(1/2)
... arr[arr >= 0.5] = 1
... return arr
...
In[18]: %time future_compute = client.compute(do_stuff(x)); wait(future_compute)
CPU times: user 387 ms, sys: 49.5 ms, total: 436 ms
Wall time: 12.1 s
In[19]: future_compute
Out[19]: <Future: status: finished, type: ndarray, key: finalize-54eb04bbe03eee8af686fd43b41eb161>
In[21]: %timeit future_compute.result()
1 loop, best of 3: 19.4 s per loop
In[21]: %time do_stuff(data)
CPU times: user 4.49 s, sys: 5.02 s, total: 9.51 s
Wall time: 9.5 s
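In case it is relevant, a dask-native sketch of do_stuff: da.where replaces the in-place masked assignment, which (as far as I understand) keeps the whole computation lazy and avoids mutating chunks:

```python
import numpy as np
import dask.array as da

def do_stuff_dask(arr):
    # Same arithmetic as do_stuff above, but da.where stands in for the
    # in-place masked assignment, so no chunk is modified in place.
    arr = arr / 3. + arr ** 2 - arr ** (1 / 2)
    return da.where(arr >= 0.5, 1., arr)

data = np.random.rand(40, 10, 100)   # scaled down from (400, 100, 10000)
x = da.from_array(data, chunks=(10, 10, 100))
out = do_stuff_dask(x).compute()
```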