
I have a custom workflow that requires resampling to a higher temporal frequency, applying a ufunc, and then a groupby + mean to compute the final result.

I would like to apply this to a big xarray dataset, which is backed by a chunked dask array. For computation, I'd like to use dask.distributed.

However, when I apply this to the full dataset, the number of tasks skyrockets, overwhelming the client and, if the graph were actually submitted, most likely the scheduler and workers as well.

The xarray docs explain:

Do your spatial and temporal indexing (e.g. .sel() or .isel()) early in the pipeline, especially before calling resample() or groupby(). Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn’t been implemented in dask yet.
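
For reference, here is a minimal sketch of what that advice amounts to (assuming a dask-backed dataset ds like the one built in the example further down): the spatial indexing happens before the resample.

subset = ds.isel(x=slice(0, 100), y=slice(0, 100))   # spatial indexing first
upsampled = subset.resample(time="1h").asfreq()      # resample only the small subset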

But I really need to apply this to the full temporal axis.

So what is the best way to implement this?

My approach was to use map_blocks to apply this function to each chunk individually, so as to keep the individual xarray sub-datasets small enough.

This seems to work on a small scale, but when I use the full dataset, the workers run out of memory and quickly die.

Looking at the dashboard, the function I'm applying to the array gets executed many more times than the number of chunks I have. Shouldn't these two numbers line up?

So my questions are:

  • Is this approach valid?
  • How could I implement this workflow otherwise, besides manually implementing the resample and groupby part and putting it in a ufunc (sketched after this list)?
  • Any ideas regarding the performance issues at scale (specifically the number of executions vs chunks)?
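
This is roughly what I mean by the manual alternative in the second bullet. It is only a sketch: it assumes time is the last axis and a fixed 24 hourly steps per day, and it ignores the off-by-one at the end of the resampled series.

import numpy as np

def manual_fun(arr):
    # arr: numpy block of daily values with shape (..., n_days)
    # mimic resample(time="1h").asfreq().fillna(0): daily value at hour 0, zeros elsewhere
    hourly = np.zeros(arr.shape[:-1] + (arr.shape[-1] * 24,), dtype=arr.dtype)
    hourly[..., ::24] = arr
    # ... the real per-pixel computation on the hourly axis would go here ...
    # mimic groupby("time").mean(): average each day's 24 hourly values back to one
    return hourly.reshape(arr.shape[:-1] + (-1, 24)).mean(axis=-1)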

Here's a small example that mimics the workflow and shows the number of executions vs chunks:

from time import sleep

import dask.array
from dask.distributed import Client, LocalCluster
import numpy as np
import pandas as pd
import xarray as xr

def ufunc(x):
    # stand-in for the real per-pixel computation
    sleep(2)
    return x

def fun(x):
    # upsample to a higher temporal resolution
    x = x.resample(time="1h").asfreq().fillna(0)

    # apply the ufunc along the time axis
    x = xr.apply_ufunc(ufunc, x, input_core_dims=[["time"]], output_core_dims=[["time"]], dask="parallelized")

    # average back over dates
    x["time"] = x.time.dt.strftime("%Y-%m-%d")
    x = x.groupby("time").mean()

    return x

def create_xrds(shape):
    '''Helper function to create the example dataset.'''
    x, y, t = shape

    tv = pd.date_range(start="1970-01-01", periods=t)

    ds = xr.Dataset({
        "band": xr.DataArray(
            dask.array.zeros(shape, dtype="int16"),
            dims=["x", "y", "time"],
            coords={"x": np.arange(0, x), "y": np.arange(0, y), "time": tv},
        )
    })

    return ds


# set up distributed

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

ds = create_xrds((500,500,500)).chunk({"x": 100, "y": 100, "time": -1})

# create template

template = ds.copy()
template['time'] = template.time.dt.strftime("%Y-%m-%d")

# map fun to blocks
ds_out = xr.map_blocks(fun, ds, template=template)
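
# optional sanity check before persisting: compare the chunk grid of the output
# with the total number of tasks in the graph (__dask_graph__ is the standard
# dask collection protocol method, also implemented by xarray objects)
print(ds_out.band.data.numblocks)     # chunk grid of the output array, e.g. (5, 5, 1)
print(len(ds_out.__dask_graph__()))   # total number of tasks in the graph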

# persist

ds_out = ds_out.persist()

Using the example above, this is what the dask array (25 chunks) looks like:

[screenshot: chunk layout of the output dask array, showing 25 chunks]

But the function fun gets executed 125 times:

[screenshot: dask dashboard showing 125 executions of fun]

Val

1 Answer


Looking at the dashboard, the function I'm applying to the array gets executed many more times than the number of chunks I have. Shouldn't these two numbers line up?

This is misleading because of an unfortunate choice made when constructing the graph. The number includes the tasks that create a block of the input Dataset (one per variable per chunk) and of the output Dataset, as well as the tasks that actually apply the function. This will get fixed soon (https://github.com/pydata/xarray/pull/5007).
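
If you want to double-check this on your end, one low-tech way (a sketch, assuming printing from the workers is acceptable) is to log every real invocation from inside the mapped function; the number of printed lines in the workers' stdout should track the number of chunks, not the dashboard's task count:

def fun_logged(x):
    # thin wrapper around the original fun that records each real execution
    print("fun called on a block with sizes", dict(x.sizes))
    return fun(x)

ds_out = xr.map_blocks(fun_logged, ds, template=template).persist()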

dcherian