
I'm trying to figure out the best way to map a dask Series with a large mapping. The straightforward series.map(large_mapping) issues UserWarning: Large object of size <X> MB detected in task graph and suggests using client.scatter and client.submit, but that approach doesn't solve the problem and is in fact much slower. Passing broadcast=True to client.scatter doesn't help either.

import argparse
import distributed
import dask.dataframe as dd

import numpy as np
import pandas as pd


def compute(s_size, m_size, npartitions, scatter, broadcast, missing_percent=0.1, seed=1):
    np.random.seed(seed)
    mapping = dict(zip(np.arange(m_size), np.random.random(size=m_size)))
    ps = pd.Series(np.random.randint(int((1 + missing_percent) * m_size), size=s_size))
    ds = dd.from_pandas(ps, npartitions=npartitions)
    if scatter:
        mapping_futures = client.scatter(mapping, broadcast=broadcast)
        future = client.submit(ds.map, mapping_futures)
        return future.result()
    else:
        return ds.map(mapping)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', default=200000, type=int, help='series size')
    parser.add_argument('-m', default=50000, type=int, help='mapping size')
    parser.add_argument('-p', default=10, type=int, help='partitions number')
    parser.add_argument('--scatter', action='store_true', help='Scatter mapping')
    parser.add_argument('--broadcast', action='store_true', help='Broadcast mapping')
    args = parser.parse_args()

    client = distributed.Client()
    ds = compute(args.s, args.m, args.p, args.scatter, args.broadcast)
    print(ds.compute().describe())
gsakkis

1 Answer


Your problem is here:

In [4]: mapping = dict(zip(np.arange(50000), np.random.random(size=50000)))

In [5]: import pickle

In [6]: %time len(pickle.dumps(mapping))
CPU times: user 2.24 s, sys: 18.6 ms, total: 2.26 s
Wall time: 2.25 s
Out[6]: 6268809

So mapping is big and unpartitioned; scattering it as a single object is what causes the problem in this case.

Consider the alternative

def make_mapping():
    return dict(zip(np.arange(50000), np.random.random(size=50000)))

mapping = client.submit(make_mapping)  # ships the function, not the data
                                       # and requires no serialisation
future = client.submit(ds.map, mapping)

This will not show the warning. However, it seems strange to me to use a dictionary here to do the mapping; a Series or plain array seems to encode the nature of the data better.
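As a rough illustration of that suggestion (assuming the keys really are the dense integer range used in the example; the names mapping_series and result are made up here):

# Encode the mapping as a pandas Series indexed by key rather than a dict;
# Series.map accepts another Series and looks values up by its index.
mapping_series = pd.Series(np.random.random(size=50000), index=np.arange(50000))
result = ds.map(mapping_series)

Whether this avoids the large-object warning still depends on the size of the Series, but it describes key-to-value data more directly than a Python dict.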

mdurant
  • Thanks, I ended up doing something similar. The posted sample code was just for illustration, in the actual use case the mapping computation is more complex and expensive to recompute in every worker. So instead I compute it once, pickle it and then have the workers unpickle it. An extra optimization I discovered was the use of `distributed.worker.thread_state` to cache the unpickled mapping for all multithreaded workers in the same process. – gsakkis Jun 07 '18 at 08:04
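A rough sketch of the caching approach described in that comment (the pickle path, the mapping_cache attribute name, and the meta value are invented for illustration; this is not the commenter's exact code):

import pickle

from distributed.worker import thread_state

MAPPING_PATH = '/shared/mapping.pkl'  # hypothetical path readable by every worker

def map_partition(part):
    # Unpickle the mapping the first time this worker thread needs it and
    # cache it on distributed.worker.thread_state so later partitions reuse it.
    cached = getattr(thread_state, 'mapping_cache', None)
    if cached is None:
        with open(MAPPING_PATH, 'rb') as f:
            cached = pickle.load(f)
        thread_state.mapping_cache = cached
    return part.map(cached)

result = ds.map_partitions(map_partition, meta=(None, 'float64'))

Because the pickled bytes never enter the task graph (only the small function does), the large-object warning is avoided.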