I'm trying to persist 1.5 million images to a dask cluster as a dask array, and then compute some summary stats. I'm following an image processing tutorial from @mrocklin's blog and have pared my script down to a minimal reproducible example:
import time

import dask
import dask.array as da
import numpy as np
from distributed import Client

client = Client()

def get_imgs(num_imgs):
    def get():
        # stand-in for loading one image; cast to uint16 so the data
        # matches the dtype declared in from_delayed below
        arr = np.random.randint(2000, size=(3, 120, 120)).astype(np.uint16).flatten()
        return arr
    delayed_get = dask.delayed(get)
    return [da.from_delayed(delayed_get(), shape=(3 * 120 * 120,), dtype=np.uint16)
            for num in range(num_imgs)]

imgs = get_imgs(1500000)
imgs = da.stack(imgs, axis=0)
client.persist(imgs)
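For reference, the way this is built, each image ends up as its own chunk and its own delayed task, so the stacked array carries on the order of 1.5 million tasks before any data is loaded. A quick check (just a sketch, assuming recent dask Array attributes):

print(imgs.npartitions)            # one chunk per image, so 1,500,000
print(len(imgs.__dask_graph__()))  # total tasks in the graph, at least that many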
The persist step causes my Jupyter process to crash. Is that because the persist step causes a bunch of operations to be done on each object in the collection, and the collection is too large to fit in memory? So I tried scatter instead:
start = time.time()
imgs_future = client.scatter(imgs, broadcast=True)
print(time.time() - start)
But the Jupyter process crashes, or the network connection to the scheduler gets lost.
So I tried breaking up the scatter step:
st = time.time()
chunk_size = 50000
chunk_num = 0
chunk_futures = []
start = 0
end = start + chunk_size
is_last_chunk = False

# clear out any datasets published on a previous run
for dataset in client.list_datasets():
    client.unpublish_dataset(dataset)

while True:
    cst = time.time()
    chunk = imgs[start:end]
    cst1 = time.time()
    if start == 0:
        print('loaded chunk in', cst1 - cst)
    if len(chunk) == 0:
        break
    chunk_future = client.scatter(chunk)
    chunk_futures.append(chunk_future)
    dataset_name = "chunk_{}".format(chunk_num)
    client.publish_dataset(**{dataset_name: chunk_future})
    if start == 0:
        print('submitted chunk in', time.time() - cst1)
    start = end
    if is_last_chunk:
        break
    chunk_num += 1
    end = start + chunk_size
    if end > len(imgs):
        is_last_chunk = True
        end = len(imgs)
    if start == end:
        break
    if chunk_num % 5 == 0:
        print('chunk_num', chunk_num, 'start', start)

print('completed in', time.time() - st)
But this approach results in the connection being lost as well. What's the recommended approach for persisting a large image dataset to a cluster asynchronously?
I've looked at the delayed best practices, and what jumps out at me is that I may be using too many tasks, so maybe I need to do more batching in each get() call.
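Something like the following is what I have in mind for batching: each delayed call returns a whole block of images instead of a single one, so 1.5 million images becomes on the order of 150 tasks rather than 1.5 million. This is only a rough sketch reusing the imports and client from the first snippet, and batch_size is an arbitrary choice:

def get_img_batches(num_imgs, batch_size=10000):
    def get_batch(n):
        # one task now produces a whole batch of flattened images
        return np.random.randint(2000, size=(n, 3 * 120 * 120)).astype(np.uint16)
    delayed_batch = dask.delayed(get_batch)
    blocks = []
    for i in range(0, num_imgs, batch_size):
        n = min(batch_size, num_imgs - i)
        blocks.append(
            da.from_delayed(delayed_batch(n), shape=(n, 3 * 120 * 120), dtype=np.uint16)
        )
    return da.concatenate(blocks, axis=0)

imgs = get_img_batches(1500000)
imgs = client.persist(imgs)

I realize the underlying data (1,500,000 x 3 x 120 x 120 uint16 values, roughly 130 GB) still has to fit in cluster memory either way, so maybe the task count isn't the whole story here.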