
I can't apply parcel_profile (a 1D function) to a 4D xarray dataset chunked with dask.

Hello,

I really need help. I'm working with ERA5 hourly data on pressure levels. I've extracted relative humidity and temperature on several pressure levels and use MetPy's dewpoint_from_relative_humidity function to compute the dewpoint. The dataset is quite large, so I use dask to split it into several chunks. Here is my setup:

import dask
import numpy as np
import metpy as mp
import metpy.calc as mpcalc
import xarray as xr
import metpy.units as mpunit

with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ncin_1 = xr.open_dataset(ncfname_1, engine='netcdf4').chunk('auto')
    ncin_1['t'] = ncin_1['t'] - 273.15
    ncin_1['r'] = ncin_1['r'] / 100.0
    ncin_1['r'] = ncin_1['r'].clip(min=0.01, max=1)

# variable reprocessing with metpy unit registry
ncin_1['level'].attrs['units'] = 'hPa'
ncin_1 = ncin_1.metpy.quantify()
ncin_1['time']  = ncin_1['time'].metpy.time 
ncin_1['latitude']  = ncin_1['latitude'].metpy.latitude
ncin_1['longitude']  = ncin_1['longitude'].metpy.longitude
ncin_1['level']  = ncin_1['level'].metpy.vertical
ncin_1['t'] = ncin_1['t'] * mpunit.units.degC
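# 'dewpoint' was computed beforehand (not shown here), with something like:
# ncin_1['dewpoint'] = mpcalc.dewpoint_from_relative_humidity(ncin_1['t'], ncin_1['r'])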
ncin_1['dewpoint'] = ncin_1['dewpoint'] * mpunit.units.degC
ncin_1['r'] = ncin_1['r'] * mpunit.units.dimensionless

# pressures levels in descending order
ncin_1 = ncin_1.where((ncin_1['level'] >= 100) & (ncin_1['level'] <= 1000), drop=True)
ncin_1 = ncin_1.sortby('level', ascending=False)

# create a profile variable to be filled in
ncin_1['profil'] = (('time', 'level', 'latitude', 'longitude'), np.full_like(ncin_1['t'], fill_value=np.nan))

The resulting dataset looks like this:
<xarray.Dataset>
Dimensions:    (longitude: 60, latitude: 41, level: 21, time: 24836)
Coordinates:
  * longitude  (longitude) float32 -5.02 -4.77 -4.52 -4.27 ... 9.23 9.48 9.73
  * latitude   (latitude) float32 51.15 50.9 50.65 50.4 ... 41.65 41.4 41.15
  * level      (level) int32 20 50 100 150 200 250 ... 750 800 850 900 950 1000
  * time       (time) datetime64[ns] 2006-01-01 ... 2022-12-31T18:00:00
Data variables:
    r          (time, level, latitude, longitude) float32 0.03323 ... 0.7702
    t          (time, level, latitude, longitude) float32 -69.07 ... 14.25
    dewpoint   (time, level, latitude, longitude) float32 -90.23 ... 10.28
    profil     (time, level, latitude, longitude) float32 nan nan ... nan nan
Attributes:
    Conventions:  CF-1.6
    history:      2023-07-19 22:32:41 GMT by grib_to_netcdf-2.25.1: /opt/ecmw... 

I'd like to calculate the lifted index (using MetPy's lifted_index function), but first I need to compute the parcel_profile variable. The problem is that, according to the documentation, this function only works on 1D profiles. The single-column sketch below shows what I'm trying to scale up; my attempts with xarray.apply_ufunc and xarray.map_blocks follow.
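For a single grid point, the calculation I'm trying to scale up looks roughly like this (a minimal sketch with made-up values; all the variable names here are purely illustrative):

# Minimal single-column sketch: parcel_profile gives the parcel temperature at each
# pressure level, which then feeds lifted_index together with the environmental profile.
import numpy as np
import metpy.calc as mpcalc
from metpy.units import units

pressure = np.array([1000., 950., 900., 850., 800., 700., 600., 500., 400., 300.]) * units.hPa
temperature = np.array([14., 11., 8., 5., 1., -7., -15., -24., -35., -48.]) * units.degC
dewpoint_sfc = 10. * units.degC

prof = mpcalc.parcel_profile(pressure, temperature[0], dewpoint_sfc).to('degC')
li = mpcalc.lifted_index(pressure, temperature, prof)  # needs 500 hPa within the profile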

  1. with xarray.apply_ufunc:
def wrapper_parcel_profile(pressure, temperature, dewpoint):
    return mpcalc.parcel_profile(pressure * mpunit.units.hPa, temperature * mpunit.units.degC, dewpoint * mpunit.units.degC).to('degC')


t_1000 = ncin_1['t'].metpy.sel(level=1000)
dewpoint_1000 = ncin_1['dewpoint'].metpy.sel(level=1000) 
pressure = ncin_1['level'] 

ncin_1['profil'] = xr.apply_ufunc(
    wrapper_parcel_profile,  
    pressure, t_1000, dewpoint_1000,  
    input_core_dims=[['level'], ['time', 'latitude', 'longitude'], ['time', 'latitude', 'longitude']], 
    output_core_dims=[['time', 'level' , 'latitude', 'longitude']],
    vectorize=True,
    dask='parallelized', 
    output_dtypes=[float],
    dask_gufunc_kwargs={'allow_rechunk': True}
)

The script itself runs, which I understand is because dask is lazy: nothing is actually calculated until I call ncin_1.compute(). When I do compute, I get this error message:

ValueError: operands could not be broadcast together with shapes (19,) (24836,41,60)

(Presumably something to do with ncin_1['level'], which has 19 values after the filtering above?)

  2. with xarray.map_blocks:
def wrapper_parcel_profile(pressure, temperature, dewpoint):
    return mpcalc.parcel_profile(pressure * mpunit.units.hPa, temperature * mpunit.units.degC, dewpoint * mpunit.units.degC).to('degC')

pressure = ncin_1['level']
t_1000 = ncin_1['t'].metpy.sel(level=1000) * mpunit.units.degC
dewpoint_1000 = ncin_1['dewpoint'].metpy.sel(level=1000) * mpunit.units.degC
ncin_1['profil'] = xr.map_blocks(wrapper_parcel_profile, ncin_1, template=ncin_1['t'])

I get this error message when I call ncin_1.compute() on my dataset:

TypeError: Mixing chunked array types is not supported, but received multiple types: {, }

Is my approach the right one? Is it possible to do this simply by staying in an xarray dataset? Are the solutions I've found appropriate? Thanks in advance for your help

nietreil

2 Answers


I think you're on the right path.

I'm not familiar with parcel_profile, but it looks like it consumes and produces 1D arrays? If so, xr.apply_ufunc with just one core dimension (presumably levels) should work.

apply_ufunc is a powerful function, with many options. The trick to getting it to work is to start with a very simple case that you understand, then increase the generality until your code can handle your full problem. e.g.

  1. make it work in 1D for in-memory (pre-loaded) data
  2. make it work in ND for in-memory data (so it needs to broadcast in your case)
  3. make it work in 1D for lazy dask data
  4. make it work in ND for lazy dask data
  5. check that your data is chunked in such a way that your function can handle it (e.g. are the core dims chunked?)

One resource that is very helpful is the xarray tutorial documentation on apply_ufunc.
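As a concrete starting point for steps 3 and 4, something along these lines should be close (an untested sketch; it assumes pressure, t_1000 and dewpoint_1000 are the arrays from your question with the pint units stripped off, pressure in hPa in descending order, and the level dimension not split across chunks; profile_1d and profil are just illustrative names):

# Sketch only: apply a 1D parcel_profile over a chunked 4D dataset with apply_ufunc.
import xarray as xr
import metpy.calc as mpcalc
from metpy.units import units

def profile_1d(pressure, temperature, dewpoint):
    # pressure: 1D (level,); temperature/dewpoint arrive as 0-d values via vectorize
    prof = mpcalc.parcel_profile(pressure * units.hPa,
                                 float(temperature) * units.degC,
                                 float(dewpoint) * units.degC)
    return prof.to('degC').magnitude

profil = xr.apply_ufunc(
    profile_1d,
    pressure,                             # dims: (level,)
    t_1000,                               # dims: (time, latitude, longitude)
    dewpoint_1000,
    input_core_dims=[['level'], [], []],  # only pressure carries the core dim
    output_core_dims=[['level']],         # the profile comes back on the same levels
    vectorize=True,                       # loop the 1D function over time/lat/lon
    dask='parallelized',
    output_dtypes=[float],
)
# result dims are (time, latitude, longitude, level); transpose if you need another order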

Other comments:

  • You almost certainly shouldn't need dask_gufunc_kwargs={'allow_rechunk': True}.
  • You don't need vectorize=True for a function that already understands ND arrays; that option is for looping a function that only handles scalars (or, as here, a single 1D profile) over the remaining dimensions.
ThomasNicholas

After some research, I couldn't find a way to parallelize this calculation with dask or xarray functions alone, at least not with the knowledge I currently have.

The difficulty with this 1D function is that its inputs and output do not all share the same dimensions, and it needs concrete NumPy/pint objects (each input effectively gets a .compute(), which defeats the dask and xarray wrappers). My workaround is the script below, which parallelizes a plain loop over grid points with multiprocessing. Optimized as best I could, it took about 2 hours on my PC (Intel i7 12th gen, 20 vCPUs at 3 GHz, 32 GB RAM).

Suppose you save the code below in worker.py (it is imported as worker later):

import metpy as mp
import metpy.calc as mpcalc
import metpy.units as mpunit

def calculer_profil_pour_coordonnees(args):
    # args = (time index, lat index, lon index, T at 1000 hPa, Td at 1000 hPa, pressure levels)
    t, lat, lon, temperature, dewpoint, levels = args
    parcel_profile = mpcalc.parcel_profile(levels, temperature, dewpoint).to('degC')
    # return the indices as well, so the result can be written back into the right cell
    return lat, lon, t, parcel_profile.magnitude

Then, in your main script, you can call the worker function and run the loop with multiprocessing:

import numpy as np
import xarray as xr
from multiprocessing import Pool
from tqdm import tqdm
import metpy.units as mpunit
import worker  # the .py file containing the loop function above


temperature_1000 = ncin_1['t'].metpy.sel(level=1000).data.compute()
dewpoint_1000 = ncin_1['dewpoint'].metpy.sel(level=1000).data.compute()

levels = [1000, 950, 900, 850, 800, 750, 700, 650, 600, 550, 500, 450, 400, 350,
          300, 250, 200, 150, 100] * mpunit.units.hPa

time_dim = ncin_1.dims['time']
lat_dim = ncin_1.dims['latitude']
lon_dim = ncin_1.dims['longitude']

def compute_parcel_profile(temp_chunk, dewpoint_chunk, levels):
    profiles = np.empty((time_dim, len(levels), lat_dim, lon_dim))

    args_list = [(t, lat, lon, temp_chunk[t, lat, lon], dewpoint_chunk[t, lat, lon], levels)
                 for t in range(time_dim) for lat in range(lat_dim) for lon in range(lon_dim)]

    num_processors = 10
    with Pool(processes=num_processors) as p:
        results = list(tqdm(p.imap(worker.calculer_profil_pour_coordonnees, args_list), total=len(args_list)))
        
    for lat, lon, t, profile in results:
        profiles[t, :, lat, lon] = profile

    return profiles

result = compute_parcel_profile(temperature_1000, dewpoint_1000, levels)

profiles_dataset = xr.Dataset(
    {
        'profile': (('time', 'level', 'latitude', 'longitude'), result)
    },
    coords={
        'time': ncin_1.time,
        'level': ncin_1.level,
        'latitude': ncin_1.latitude,
        'longitude': ncin_1.longitude
    }
).chunk()  * mpunit.units.degC

profiles_dataset['profile'].attrs['long_name'] = 'profile_a_parcel'

merged_dataset = ncin_1.merge(profiles_dataset)
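As a possible follow-up (not something I have run end to end), the lifted index this was all aiming at could be computed column by column with the same Pool/imap pattern, using a second worker along these lines (a sketch; calculer_li_pour_coordonnees is a hypothetical name, and it assumes the pressure, temperature and parcel-profile columns are passed in as pint quantities):

import numpy as np
import metpy.calc as mpcalc

def calculer_li_pour_coordonnees(args):
    # args = (time index, lat index, lon index, pressure profile, temperature profile, parcel profile)
    t, lat, lon, pressure, temperature, parcel_temperature = args
    li = mpcalc.lifted_index(pressure, temperature, parcel_temperature)
    # lifted_index typically returns a length-1 quantity; keep a scalar value
    return lat, lon, t, np.atleast_1d(li.magnitude)[0]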

I hope the MetPy package will be improved in the future, especially the metpy.calc functions like parcel_profile, mixed_layer_cape_cin or lifted_index. Don't take my message as the perfect solution; I've optimized as best I can. Another option is to rent compute, for example AWS instances, if your own machine isn't powerful enough, but that comes at a cost.

nietreil