I have a custom workflow that requires using `resample` to get to a higher temporal frequency, applying a `ufunc`, and `groupby` + `mean` to compute the final result.
I would like to apply this to a big `xarray` dataset, which is backed by a chunked `dask` array. For computation, I'd like to use `dask.distributed`.
However, when I apply this to the full dataset, the number of tasks skyrockets, overwhelming the client and most likely also the scheduler and workers if submitted.
The xarray docs explain:
Do your spatial and temporal indexing (e.g. .sel() or .isel()) early in the pipeline, especially before calling resample() or groupby(). Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn’t been implemented in dask yet.
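For reference, my understanding of that indexing-first advice is something like the following sketch (assuming a dataset like the `ds` from the example further down; the slice bounds are arbitrary placeholders):

# spatial indexing *before* resample/groupby, as the docs recommend
subset = ds.isel(x=slice(0, 100), y=slice(0, 100))
# only then upsample and group back down to daily means
up = subset.resample(time="1h").asfreq().fillna(0)
up["time"] = up.time.dt.strftime("%Y-%m-%d")
daily = up.groupby("time").mean()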
But I really need to apply this to the full temporal axis.
So what is the best way to implement this?
My approach was to use `map_blocks` to apply this function to each chunk individually, so as to keep the individual `xarray` sub-datasets small enough.
This seems to work on a small scale, but when I use the full dataset, the workers run out of memory and quickly die.
Looking at the dashboard, the function I'm applying to the array gets executed several times per chunk, i.e. far more often than the number of chunks. Shouldn't these two numbers line up?
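As a rough way to compare the two numbers on the client side, I can do something like this sketch (using the `ds` and `ds_out` from the example below; this only counts tasks in the graph, not actual executions, but it shows whether the graph itself already contains far more tasks than chunks):

# number of chunks in the input array
print(ds.band.data.npartitions)        # 25 in the example below
# total number of tasks in the graph produced by map_blocks
print(len(ds_out.__dask_graph__()))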
So my questions are:
- Is this approach valid?
- How could I implement this workflow otherwise, besides manually implementing the `resample` and `groupby` part and putting it in a `ufunc` (roughly like the sketch after this list)?
- Any ideas regarding the performance issues at scale (specifically the number of executions vs. chunks)?
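For illustration, by "manually implementing the resample and groupby part in a ufunc" I mean something like the following rough numpy-only sketch (it reuses the `ufunc` and `ds` from the example below; `manual_pipeline` and the hard-coded factor of 24 hourly slots per day are my own simplifications, and edge effects at the end of the series differ from the resample-based version):

import numpy as np
import xarray as xr

def manual_pipeline(arr, upsample=24):
    # arr has the (daily) time axis as its last dimension
    n_days = arr.shape[-1]
    # upsample: put each daily value in the first hourly slot, zeros elsewhere
    hourly = np.zeros(arr.shape[:-1] + (n_days * upsample,), dtype=arr.dtype)
    hourly[..., ::upsample] = arr
    # apply the expensive computation on the hourly series
    hourly = ufunc(hourly)
    # average back down to daily resolution
    return hourly.reshape(arr.shape[:-1] + (n_days, upsample)).mean(axis=-1)

result = xr.apply_ufunc(
    manual_pipeline, ds.band,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
    output_dtypes=[float],
)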
Here's a small example that mimics the workflow and shows the number of executions vs chunks:
from time import sleep

import dask
import dask.array  # dask.array has to be imported explicitly for dask.array.zeros
from dask.distributed import Client, LocalCluster
import numpy as np
import pandas as pd
import xarray as xr


def ufunc(x):
    # placeholder for the expensive computation
    sleep(2)
    return x


def fun(x):
    # upsample to higher temporal resolution
    x = x.resample(time="1h").asfreq().fillna(0)

    # apply function along the time axis
    x = xr.apply_ufunc(
        ufunc, x,
        input_core_dims=[["time"]],
        output_core_dims=[["time"]],
        dask="parallelized",
    )

    # average over dates
    x["time"] = x.time.dt.strftime("%Y-%m-%d")
    x = x.groupby("time").mean()
    return x


def create_xrds(shape):
    '''helper function to create the dataset'''
    x, y, t = shape
    tv = pd.date_range(start="1970-01-01", periods=t)
    ds = xr.Dataset({
        "band": xr.DataArray(
            dask.array.zeros(shape, dtype="int16"),
            dims=["x", "y", "time"],
            coords={"x": np.arange(0, x), "y": np.arange(0, y), "time": tv},
        )
    })
    return ds


# set up distributed
cluster = LocalCluster(n_workers=2)
client = Client(cluster)

ds = create_xrds((500, 500, 500)).chunk({"x": 100, "y": 100, "time": -1})

# create template for map_blocks (same shape, time coordinate as date strings)
template = ds.copy()
template["time"] = template.time.dt.strftime("%Y-%m-%d")

# map fun to blocks
ds_out = xr.map_blocks(fun, ds, template=template)

# persist (keep the returned reference so the results are not released)
ds_out = ds_out.persist()
Using the example above, this is what the dask array (25 chunks) looks like:
But the function `fun` gets executed 125 times: