4

According to "What threads do Dask Workers have active?", a Dask worker has

A pool of threads in which to run tasks.

The documentation says

If your computations are mostly numeric in nature (for example NumPy and Pandas computations) and release the GIL entirely then it is advisable to run dask-worker processes with many threads and one process. This reduces communication costs and generally simplifies deployment.

The internals of NumPy use MKL or OpenBLAS, with a number of threads equal to the env variables OPENBLAS_NUM_THREADS or MKL_NUM_THREADS when the code is normally executed.

How do those parameters and dask's computation threads work together?

Labo
  • 2,482
  • 2
  • 18
  • 38

2 Answers

3

Short answer

Poorly

Longer answer

By default most modern BLAS/LAPACK implementations use as many threads as you have logical cores. Dask will do the same (assuming the default configuration of one worker thread per core). If you're doing level-3 BLAS operations (like matrix-matrix multiplication) then this can result in many more active threads than you have cores, and a general degradation of performance from oversubscription.

I typically set XXX_NUM_THREADS=1 and rely on Dask for parallelism when using both together.
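As a concrete sketch of that advice: the variable names below cover the common OpenBLAS, MKL, and OpenMP builds (which ones apply depends on how your NumPy was compiled), and they must be set before NumPy first loads, because the BLAS libraries read them only once at import time.

```python
import os

# Pin the BLAS/OpenMP threadpools to one thread each. These must be set
# BEFORE NumPy (and its BLAS) is imported; they are read only at load time.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"  # covers OpenMP-based builds

import numpy as np

# Any BLAS call in a Dask task (e.g. a matrix product) now runs
# single-threaded, so parallelism comes only from Dask's own threads.
a = np.random.random((500, 500))
b = a @ a.T  # single-threaded GEMM under the limits above
```

The same variables can of course be exported in the shell before launching `dask-worker`, which is the usual way to apply them to a whole cluster.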

MRocklin
  • 55,641
  • 23
  • 163
  • 235
  • If I understand, by default each task uses all the BLAS threads it can for parallelism but you suggest to set up a variable to use only one? – Labo Oct 20 '18 at 09:19
  • Every Dask task runs in a separate thread. Dask places no constraints on what that task/function does; if it uses many threads itself then things can become inefficient. You can constrain most BLAS implementations by setting environment variables, as you suggest in your question. – MRocklin Oct 21 '18 at 13:35
1

Not answering how dask threads interact with BLAS, but following up on MRocklin's answer: threadpoolctl provides a nice interface for controlling the number of threads used by BLAS, and seems to work fine with dask workers. You can try:

import dask.array as da
from threadpoolctl import threadpool_limits

# A tall-skinny random array; each chunk is 5000 rows by all 2000 columns.
x = da.random.random((1000000, 2000), chunks=(5000, -1))
xtx = x.T @ x  # lazy Gram matrix; nothing is computed yet

# Cap every detected BLAS threadpool at one thread while the graph runs.
with threadpool_limits(limits=1, user_api='blas'):
    xtx.compute()

For me it is about 15-20% faster when using the limiter.
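To check that the limit actually takes effect, threadpoolctl also exposes `threadpool_info()`, which reports the current size of each detected threadpool. A small sketch (what it detects depends on how your NumPy/BLAS was built):

```python
import numpy as np  # importing NumPy loads its BLAS threadpool
from threadpoolctl import threadpool_info, threadpool_limits

# Outside the context manager, BLAS pools use their default size
# (typically one thread per logical core).
before = [p["num_threads"] for p in threadpool_info()
          if p["user_api"] == "blas"]

with threadpool_limits(limits=1, user_api="blas"):
    # Inside, every detected BLAS threadpool is capped at one thread.
    during = [p["num_threads"] for p in threadpool_info()
              if p["user_api"] == "blas"]

print(before, during)
```

Note that this inspects the process it runs in; on a distributed cluster the limit would need to be applied inside each worker process rather than only on the client.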

thomaskeefe
  • 1,900
  • 18
  • 19