I am new to Xarray and Dask and am trying to access multiple netCDF files that store global ocean current velocities at 3-hour intervals. Each netCDF file covers one time step of gridded data at 1/4 degree resolution:
NetCDF dimension information:
    Name: time
        size: 1
        type: dtype('float32')
        _FillValue: 9.96921e+36
        units: 'days since 1950-01-01 00:00:00 UTC'
        calendar: 'julian'
        axis: 'T'
    Name: lat
        size: 720
        type: dtype('float32')
        _FillValue: 9.96921e+36
        units: 'degrees_north'
        axis: 'Y'
    Name: lon
        size: 1440
        type: dtype('float32')
        _FillValue: 9.96921e+36
        units: 'degrees_east'
        axis: 'X'
NetCDF variable information:
    Name: eastward_eulerian_current_velocity, northward_eulerian_current_velocity
        dimensions: ('time', 'lat', 'lon')
        size: 1036800
        type: dtype('float32')
        _FillValue: 9.96921e+36
        coordinates: 'lon lat'
        horizontal_scale_range: 'greater than 100 km'
        temporal_scale_range: '10 days'
        units: 'm s-1'
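For completeness, a roughly equivalent view of a single file can be obtained with xarray (the filename below is just a placeholder):

import xarray as xr

# Inspect one 3-hourly file; the filename is a placeholder
ds_single = xr.open_dataset('currents_20160101T00.nc')
print(ds_single)        # dimensions, coordinates and data variables
print(ds_single['eastward_eulerian_current_velocity'].attrs)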
To evaluate the travel time of routes in a ship route planner, I am trying to find the fastest way to select the current velocity values (northward and eastward) for a specific longitude/latitude combination and time index.
Current approach:
First, I open multiple datasets using xarray.open_mfdataset() and, after preprocessing, store the combined dataset with xarray.Dataset.to_netcdf(), keeping type float32. Then, I re-open the dataset and let Dask auto-chunk it. For 50 days of concatenated data (n_days = 50) this results in the following chunk sizes:
- 'time': 200
- 'lon': 288
- 'lat': 240
200 * 288 * 240 * 4 bytes (float32) * 2 variables = 110.6 MB (correct?)
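As a sanity check on that arithmetic (each variable's chunk is ~55.3 MB, so one chunk of each of the two variables is ~110.6 MB):

# Sanity check of the per-chunk memory estimate
elems_per_chunk = 200 * 288 * 240      # time * lon * lat
bytes_per_var = elems_per_chunk * 4    # float32 = 4 bytes
print(bytes_per_var / 1e6)             # ~55.3 MB per variable
print(bytes_per_var * 2 / 1e6)         # ~110.6 MB for u_knot and v_knot together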
Storing and opening these files is expensive when n_days is large. For n_days = 50, the stored netCDF file is 3.08 GB.
import os

import numpy as np
import xarray as xr


def load_current_data(start_date, n_days):
    # start_date: datetime(2016, 1, 1)
    # n_days: 50 --> loading 50 * 8 = 400 netCDF files
    # ds_fp: dataset filepath
    # local_paths: list of netCDF file paths
    if os.path.exists(ds_fp):
        return xr.open_dataset(ds_fp, chunks={'time': 'auto', 'lon': 'auto', 'lat': 'auto'})

    # Try opening current data locally, otherwise download from FTP server
    try:
        ds = xr.open_mfdataset(local_paths,
                               parallel=True,
                               combine='by_coords',
                               preprocess=convert_to_knots)
    except FileNotFoundError:
        # Download files from FTP server and save to local_paths
        ds = xr.open_mfdataset(local_paths,
                               parallel=True,
                               combine='by_coords',
                               preprocess=convert_to_knots)
    ds.to_netcdf(ds_fp)
    return xr.open_dataset(ds_fp, chunks={'time': 'auto', 'lon': 'auto', 'lat': 'auto'})
Function for preprocessing the data in xarray.open_mfdataset():
def convert_to_knots(ds):
    # Convert current velocities from m/s to knots (1 m/s = 1.94384 knots)
    ds.attrs = {}
    arr2d = np.float32(np.ones((720, 1440)) * 1.94384)
    ds['u_knot'] = arr2d * ds['eastward_eulerian_current_velocity']
    ds['v_knot'] = arr2d * ds['northward_eulerian_current_velocity']
    ds = ds.drop_vars(['eastward_eulerian_current_velocity',
                       'eastward_eulerian_current_velocity_error',
                       'northward_eulerian_current_velocity',
                       'northward_eulerian_current_velocity_error'])
    return ds
load_current_data() returns a chunked xarray.Dataset, see below.
<xarray.Dataset>
Dimensions: (lat: 720, lon: 1440, time: 400)
Coordinates:
* lon (lon) float32 -179.875 -179.625 -179.375 ... 179.625 179.875
* lat (lat) float32 -89.875 -89.625 -89.375 ... 89.375 89.625 89.875
* time (time) object 2016-01-01 00:00:00 ... 2016-02-19 21:00:00
Data variables:
u_knot (time, lat, lon) float32 dask.array<chunksize=(200, 240, 288), meta=np.ndarray>
v_knot (time, lat, lon) float32 dask.array<chunksize=(200, 240, 288), meta=np.ndarray>
To get the actual current velocities at a grid point I wrote the class below. Its get_grid_pt_current() method takes a lon/lat index and the datetime of the required netCDF file (e.g. 2016-01-01 00:00:00). The class loads the data with load_current_data() in its constructor and persists it into memory; this is, however, only possible if the data fits into memory. Finally, the values corresponding to the lon/lat index and datetime arguments are selected with xarray.Dataset.isel().load(), see below.
I use dask.cache to cache computations.
import datetime

from dask.cache import Cache


class CurrentData:
    def __init__(self, start_date, n_days):
        self.start_date = start_date
        self.n_days = n_days

        # Persist data into memory
        self.ds = load_current_data(self.start_date, self.n_days).persist()

        cache = Cache(2e9)  # Leverage two gigabytes of memory
        cache.register()    # Turn cache on globally

    def get_grid_pt_current(self, date_in, lon_idx, lat_idx):
        # Index of the 3-hourly time step since start_date
        delta = date_in - self.start_date
        time_idx = int(delta.total_seconds()) // (3600 * 3)
        vals = self.ds.isel(time=time_idx, lat=lat_idx, lon=lon_idx).load()
        u = float(vals['u_knot'])
        v = float(vals['v_knot'])
        return u, v
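For context, a typical call looks roughly like this (the date and indices are placeholders):

from datetime import datetime

# 50 days of 3-hourly data starting 2016-01-01
current_data = CurrentData(start_date=datetime(2016, 1, 1), n_days=50)

# Current components (in knots) at one grid point, 2016-01-02 06:00
u, v = current_data.get_grid_pt_current(datetime(2016, 1, 2, 6, 0),
                                        lon_idx=700, lat_idx=360)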
Is it possible to improve the performance of this code? And, specifically:
- The way I use persist() is more expensive for larger datasets. Is there a better way to efficiently access the data without persist(), or by using it differently? I also tried load() to load the data into memory. This is much faster; however, it does not work for datasets larger than memory.
- Could setting compute=False in xarray.Dataset.to_netcdf() help improve the performance? The stored netCDF file is much smaller, hence storing and opening it is faster. The values need to be computed at a later stage, but I didn't manage to do this. (Using the returned dask.delayed object?)
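Regarding the second point, my understanding is that to_netcdf(compute=False) returns a dask.delayed object and the data is only computed and written once compute() is called on it; a minimal sketch of what I mean (with ds and ds_fp as above):

# Deferred write: nothing is computed or written to disk until .compute() is called
delayed_write = ds.to_netcdf(ds_fp, compute=False)
delayed_write.compute()   # triggers the computation and writes the full file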