I am interpolating data from the ocean component of CMIP6 models onto a regular 2x2 degree grid. The field has dimensions (time, nav_lat, nav_lon) and contains NaN values over the continents. Here, nav_lon and nav_lat are two-dimensional arrays that define a curvilinear grid. I can do the interpolation with griddata from scipy, but I have to loop over time, which makes it very slow when the data has thousands of time records. My question is: how can I vectorize the interpolation over time?
The following is my code:
import xarray as xr
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
source = xr.open_dataset('data/zos_2850.nc',decode_times=False)
# obtain old lon and lat (and put lon in 0-360 range)
# nav_lon is from -180 to 180, not in 0-360 range
loni, lati = source.nav_lon.values%360, source.nav_lat.values
# flatten the source coordinates
loni_flat, lati_flat = loni.flatten(), lati.flatten()
# define a 2x2 lon-lat grid
lon, lat = np.linspace(0,360,181), np.linspace(-90,90,91)
# create mesh
X, Y = np.meshgrid(lon,lat)
# loop over time
ntime = len(source.time)
tmp = []
for t in range(ntime):
    print(t)
    var_s = source.zos[t].values
    var_s_flat = var_s.flatten()
    # index marks the points that hold valid (non-NaN) values
    index = np.where(~np.isnan(var_s_flat))
    # remap the valid values to the new grid
    var_t = griddata((loni_flat[index], lati_flat[index]), var_s_flat[index], (X, Y),
                     method='cubic')
    # interpolate the mask using nearest neighbour
    maskinterp = griddata((loni_flat, lati_flat), var_s_flat, (X, Y), method='nearest')
    # re-mask the interpolated data
    var_t[np.isnan(maskinterp)] = np.nan
    tmp.append(var_t)
# convert to DataArray
da = xr.DataArray(data=tmp,
                  dims=["time", "lat", "lon"],
                  coords=dict(lon=(["lon"], lon), lat=(["lat"], lat), time=source['time']))
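One direction I have considered, as a rough sketch only: griddata with method='cubic' is built on scipy's CloughTocher2DInterpolator, which accepts values of shape (npoints, ...). If the NaN (land) mask is the same at every time step, which is an assumption I have not verified for all models, the triangulation could be built once and all time steps interpolated in a single call. The function name regrid_all_times below is hypothetical, and field stands for source.zos.values.

```python
import numpy as np
from scipy.interpolate import CloughTocher2DInterpolator, NearestNDInterpolator


def regrid_all_times(loni, lati, field, X, Y):
    """Interpolate field (time, y, x) from a curvilinear grid onto the mesh (X, Y).

    ASSUMPTION: the NaN (land) mask is identical at every time step, so the
    mask of the first time step is used for all of them.
    """
    ntime = field.shape[0]
    pts = np.column_stack([loni.ravel(), lati.ravel()])  # (npoints, 2)
    vals = field.reshape(ntime, -1).T                    # (npoints, ntime)
    valid = ~np.isnan(vals[:, 0])                        # ocean points at t = 0

    # Triangulate the valid points once; every time step is a value column.
    cubic = CloughTocher2DInterpolator(pts[valid], vals[valid])
    out = cubic(X, Y)                                    # (nlat, nlon, ntime)

    # Nearest-neighbour lookup of the ocean/land flag, as in the loop version.
    nearest = NearestNDInterpolator(pts, valid.astype(float))
    land = nearest(X, Y) == 0
    out[land] = np.nan                                   # masks all time steps

    return np.moveaxis(out, -1, 0)                       # (ntime, nlat, nlon)
```

With the variables from my code this would be called as regrid_all_times(loni, lati, source.zos.values, X, Y), and the result could go straight into xr.DataArray in place of tmp. If the mask changes with time, NaNs would leak into the cubic interpolation, so this sketch would not apply as written.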