
I have a dataset in which a few elements cause far more computation than the others (the work on them is quadratic), and because they are close to each other they generally end up in the same partition. I want to randomly reshuffle the elements so that the workload is distributed more or less evenly across partitions and I avoid doing all the expensive computation in a single partition.

Right now I'm downloading everything to the coordinator with a piece of code like this:

import dask.bag as db
import random

bag = ...
# gather the whole bag into local memory, shuffle it, and rebuild the bag
l = bag.compute()
random.shuffle(l)
bag = db.from_sequence(l)

Is there a way to do this in a more distributed fashion? I tried, for example, repartitioning based on a random key, but I ended up with most partitions empty.

della

1 Answer


One option is to shuffle within each partition of the bag. The limitation is that each partition is shuffled independently, so elements never move between partitions.

import random

import dask.bag as db
import matplotlib.pyplot as plt

# random.shuffle shuffles in place and returns None, so we wrap it
def shuffle(x):
    """Shuffle the partition in place and return it."""
    random.shuffle(x)
    return x

bag = db.from_sequence(list(range(2000)), npartitions=4)
# we apply the shuffle to each partition
result = bag.map_partitions(shuffle).compute()
print(result[:10], result[-10:])
# [233, 204, 181, 18, 50, 114, 424, 6, 195, 348] [1910, 1623, 1730, 1552, 1754, 1899, 1659, 1946, 1834, 1551]

plt.scatter(result, range(len(result)))

(figure: scatter plot of value against position; each of the four partitions appears as its own shuffled block)

As you can see, it only shuffled within each partition. But since bag uses multiprocessing by default and no memory is shared between partitions, this should be pretty fast.
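If the elements actually have to move between partitions, one sketch of my own (not something the bag API does in a single call) is to tag each element with a random partition number and group on that tag: bag.groupby performs a real shuffle across partitions. The names below are illustrative:

import random

import dask.bag as db

bag = db.from_sequence(list(range(2000)), npartitions=4)
n = bag.npartitions

# tag every element with a random target partition, then group on that tag;
# unlike map_partitions, groupby moves elements between partitions
keyed = bag.map(lambda x: (random.randrange(n), x))
shuffled = (
    keyed.groupby(lambda kv: kv[0], npartitions=n)  # -> (key, [(key, x), ...])
         .map(lambda kv: [x for _, x in kv[1]])     # drop the random tags
         .flatten()
)
result = shuffled.compute()

Note that groupby shuffles data between workers, so it is much more expensive than map_partitions; also, elements keep their relative order inside each group, so you can chain the per-partition shuffle above afterwards if you need a fully random order.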

Another way, if you can use a dask.array instead of a dask.bag, is with shuffle_slice, which gives a uniform shuffle across the whole array. (shuffle_slice is an internal helper in dask.array.slicing, so its location may change between Dask versions.)

import random

import numpy as np
import dask.array as da
from dask.array.slicing import shuffle_slice
import matplotlib.pyplot as plt

# build a random permutation of the indices ...
array = da.from_array(np.arange(2000))
shuffle_index = np.arange(2000)
np.random.shuffle(shuffle_index)

# ... and apply it lazily: shuffle_slice reorders the array to match
array = shuffle_slice(array, shuffle_index)
result = array.compute()
print(result[:10], result[-10:])

plt.scatter(result, range(len(result)))

(figure: scatter plot of value against position; the points spread uniformly, showing a full shuffle)

And if you can use a dask.dataframe, it may be easier to do it with random_split.
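A minimal sketch of that idea, under my own interpretation (the column name x and the four-way split are just examples): random_split assigns each row at random to one of several new collections, and concatenating them back redistributes the rows.

import numpy as np
import pandas as pd
import dask.dataframe as dd

df = dd.from_pandas(pd.DataFrame({"x": np.arange(2000)}), npartitions=4)

# randomly assign each row to one of four new collections ...
parts = df.random_split([0.25, 0.25, 0.25, 0.25], random_state=0)

# ... and concatenate them back together: every resulting partition now
# holds a random subset of the rows, which balances the workload
# (rows keep their relative order, so this is not a uniform permutation)
shuffled = dd.concat(parts)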

Lucas
  • Hey, thanks for the answer! Unfortunately the first solution doesn't work for me: I want to split work evenly between partitions, and in-partition shuffling doesn't do that. Moreover, it's not clear to me how to convert my bag (containing (int, int, int, object) tuples) to an array, nor how to use random_split to that end. Maybe you could give me some more help? – della May 11 '21 at 18:39
  • I have no more ideas on how to fix it; maybe someone with more knowledge of Dask can answer you. – Lucas May 11 '21 at 19:16