Search code examples
numpycluster-computingdaskpython-xarrayera5

How to handle large xarray/dask datasets to minimize computation time or running out of memory to store as yearly files (ERA5 dataset)


Currently, I am using ERA5-land data to calculate wind related variables. While I am able to calculate what I want, I do struggle with an efficient implementation to lift this heavy data in a feasible way. This data is stored on a "Supercomputer" that I can access and run my script on. A couple of CPU and Memory options are available but most of the time I go with either 4CPUs/18GB or 7CPUs/32GB. The data comes in monthly files that I open as a yearly files. The total number of years is 74 but I figured going year by year is the better approach.

Here is my code:

# Imports
import geopandas as gpd
from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import glob
import pickle
import xarray as xr
import time 

# setting up my workers (e.g. 7)
from dask.distributed import Client
client = Client(n_workers=7)
client

# Declaring variables
directory = 'PATH_TO_LOCAL_DIR'

years  = np.arange(1950,2023+1)

lon0 =  129.7
lat0 = - 27.9
lonbounds = [113.0,145.0]
latbounds = [-36.0,-16.0]

d     = 300e-6
z     = 10
g     = 9.8
z0    = 1e-2
rho_s = 2650
rho_f = 1.2
kappa = 0.4
phi = 0.6
fluxconstant      = 5
thresholdconstant = 0.082
usimths = np.linspace(0.01,0.8,80,dtype=np.float32)
r = 6371.229*1e+3 
s2yr = 1/60/60/24/365

uabinwidth = 3*np.pi/180
uabine = np.arange(-180*np.pi/180,(180+1e-4)*np.pi/180,uabinwidth)
uabinm = uabine[1:]-uabinwidth/2


# Opening the data. Currently I just grabe the first year worth of data
era5  = xr.open_mfdataset(glob.glob('u10_era5-land_oper_sfc_%d*.nc'%(years[i],years[i])))
era5  = era5.reindex(latitude=list(reversed(era5.latitude)))
era5  = era5.sel(longitude=slice(lonbounds[0],lonbounds[-1]),latitude=slice(latbounds[0],latbounds[-1]))       
era5v = xr.open_mfdataset(glob.glob('v10_era5-land_oper_sfc_%d*.nc'%(years[i],years[i])))
era5v = era5v.reindex(latitude=list(reversed(era5v.latitude)))
era5v = era5v.sel(longitude=slice(lonbounds[0],lonbounds[-1]),latitude=slice(latbounds[0],latbounds[-1]))
era5  = era5.merge(era5v)

The output of era5 at this point is:

 <xarray.Dataset>
 Dimensions:    (longitude: 321, latitude: 201, time: 8759)
 Coordinates:
   * longitude  (longitude) float32 113.0 113.1 113.2 113.3 ... 144.8 144.9 145.0
   * latitude   (latitude) float32 -36.0 -35.9 -35.8 -35.7 ... -16.2 -16.1 -16.0
   * time       (time) datetime64[ns] 1950-01-01T01:00:00 ... 1950-12-31T23:00:00
 Data variables:
     u10        (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     v10        (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>

Next I do:

# Calculation 
era5 = era5.assign(ua_from=np.arctan2(-era5.v10,-era5.u10))#.compute() <--- not sure if compute in between is a good idea?

for j in range(len(usimths)):
    era5 = era5.assign(temp=(((era5.u10**2+era5.v10**2)*(kappa/np.log(z/z0))**2-usimths[j]**2)*fluxconstant*usimths[j]/g*rho_f/rho_s).where(((era5.u10**2+era5.v10**2)**0.5*(kappa/np.log(z/z0)))>usimths[j],0))#.compute()
    era5['qm_{}'.format(str(round(usimths[j],2)))] = era5.temp
    era5 = era5.drop_vars('temp')

With the output of era5 being:

 Output of era5:
 <xarray.Dataset>
 Dimensions:    (longitude: 321, latitude: 201, time: 8759)
 Coordinates:
   * longitude  (longitude) float32 113.0 113.1 113.2 113.3 ... 144.8 144.9 145.0
   * latitude   (latitude) float32 -36.0 -35.9 -35.8 -35.7 ... -16.2 -16.1 -16.0
   * time       (time) datetime64[ns] 1950-01-01T01:00:00 ... 1950-12-31T23:00:00
 Data variables: (12/83)
     u10        (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     v10        (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     ua_from    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.01    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.02    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.03    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     ...         ...
     qm_0.75    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.76    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.77    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.78    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.79    (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>
     qm_0.8     (time, latitude, longitude) float32 dask.array<chunksize=(52, 100, 166), meta=np.ndarray>

Finally:

# this part takes some time and I am still unsure if I should compute or not in between
ua_from = era5.ua_from.compute().values #<-- need these for below
bin_arrays_mean = []
bin_arrays_sum = []
for ident in range(len(uabine)-1):
    era5tempid = era5.drop_vars(["u10","v10","ua_from"])
    era5tempid = era5tempid.where((ua_from >= uabine[ident]) & (ua_from < uabine[ident+1]) & (era5tempid > 0))
    bin_arrays_mean.append(era5tempid.mean(dim="time", skipna=True))#.compute())
    bin_arrays_sum.append(era5tempid.sum(dim="time", skipna=True))#.compute())

# merge and rearrange for desired dataset
era5ang_mean = xr.concat(bin_arrays_mean,dim="ang_bins").assign_coords({"ang_bins":np.arange(120, dtype=np.uint32)})
era5ang_sum = xr.concat(bin_arrays_sum,dim="ang_bins").assign_coords({"ang_bins":np.arange(120, dtype=np.uint32)})

qm_arrays_mean = []
for var in era5ang_mean:
    qm_arrays_mean.append(era5ang_mean[var].rename("qm"))

qm_arrays_sum = []
for var in era5ang_sum:
    qm_arrays_sum.append(era5ang_sum[var].rename("qm"))

era5_bins_mean = xr.concat(qm_arrays_mean, dim="usimths").assign_coords({"usimths":usimths})
era5_bins_sum = xr.concat(qm_arrays_sum, dim="usimths").assign_coords({"usimths":usimths})

era5_bins_comb = xr.Dataset({"qmmean":era5_bins_mean, "qmsum":era5_bins_sum})

Output of era5_bins_comb (NOTE: Output below is (computed) from a run of the above but with ONLY one longitude selected which took close to an hour or so):

<xarray.Dataset>
Dimensions:    (latitude: 201, ang_bins: 120, usimths: 80)
Coordinates:
    longitude  float32 113.0
  * latitude   (latitude) float32 -36.0 -35.9 -35.8 -35.7 ... -16.2 -16.1 -16.0
  * ang_bins   (ang_bins) uint32 0 1 2 3 4 5 6 7 ... 113 114 115 116 117 118 119
  * usimths    (usimths) float32 0.01 0.02 0.03 0.04 0.05 ... 0.77 0.78 0.79 0.8
Data variables:
    qmmean     (usimths, ang_bins, latitude) float32 nan nan nan ... nan nan nan
    qmsum      (usimths, ang_bins, latitude) float32 0.0 0.0 0.0 ... 0.0 0.0 0.0

In the last step I want to save the file using era5_bins_comb.to_netcdf().

I am afraid I am not too familiar with how xarray and dask work. Specifically, when and when not to use .compute(). After trying a couple of things with computing in between the above is where I ended up now. The part that takes the most time (like an hour or maybe more) is the for ident in range(len(uabine)-1): loop. The rest runs fairly fast (probably because I do not compute anything before that). Also, the last part of saving the dataset to temporarily store it, to combine it with the other 73 years later, seems to be tricky as well. It looks like I could run out of memory and/or it will take forever as well. The last thing I tested was to use era5_bins_comb.isel(usimth=0).compute() which at least finished faster but obviously would require me to iterate over all usimths and store them separately. Alternatively, I have also access to more cpus and mem e.g. 14CPUs/63GB (used to test the era5_bins_comb.isel(usimth=0).compute()) or 28CPUs/126.

It just feels like I could improve my code, skip some loops and/or make better use of the computational power. Then again, I don't know much about big data, dask and the principles one should follow. What improvements can I make so I can get the data without waiting forever and wasting resources?

Edit

Just some extra info: I ran my code again today with 14CPUs and timed the for ident in range(len(uabine)-1): loop. It takes roughly 6 minutes for one loop which would be 12hrs (12*74=888hrs for all year files!). Also the task stream in the dask dashboard looks really unoptimised. Trying to run a nested loop iterating over all the 80 variables and all the 120 uabine takes 20 seconds for each nested loop, so 20*120*80/60/60=53hrs which is even worse. Maybe, I need to find another way to avoid the uabine loop that filters my data based on the wind direction they are blowing from.


Solution

  • My supervisor came up with an efficient solution to the problem that computed everything over all years in merely 13hrs! His solution was:

    hourcount = 0
    from xhistogram.xarray import histogram
    for i in range(len(years)):
         for j in range(len(months)):
                # get u
                era5  = xr.open_mfdataset(glob.glob('%d/u10_era5-land_oper_sfc_%d%02d*.nc'%(years[i],years[i],months[j]))[0])
                era5  = era5.reindex(latitude=list(reversed(era5.latitude)))
                era5  = era5.sel(longitude=slice(lonbounds[0],lonbounds[-1]),latitude=slice(latbounds[0],latbounds[-1]))
                
                # get v
                era5v = xr.open_mfdataset(glob.glob('%d/v10_era5-land_oper_sfc_%d%02d*.nc'%(years[i],years[i],months[j]))[0])
                era5v = era5v.reindex(latitude=list(reversed(era5v.latitude)))
                era5v = era5v.sel(longitude=slice(lonbounds[0],lonbounds[-1]),latitude=slice(latbounds[0],latbounds[-1]))
                
                # merge u and v
                era5  = era5.merge(era5v)
                
                era5 = era5.assign(ua=np.arctan2(-era5.v10,-era5.u10)) # from ONLY
                era5 = era5.assign(us=((era5.u10**2+era5.v10**2)**0.5*kappa/np.log(z/z0))) # make us
                era5 = era5.drop_vars(["u10","v10"]) # drop uv
                
                era5['us'] = era5.us.expand_dims({'usithr':usimths},axis=3) # expand us over usimthr dim
                
                era5 = era5.chunk({'longitude':64,'latitude':40,'time':len(era5.time),'usithr':1})
                
                era5['us'] = ((era5.us**2-era5.usithr**2)*era5.usithr*fluxk).where(era5.us>era5.usithr,0)
                era5 = era5.rename_vars({'us':'qm'})
                
                temp = histogram(era5.ua,bins=uabine,weights=era5.qm,dim=['time']).compute()/len(era5.time)
                
                if hourcount==0:
                    final = temp
                else:
                    final = (hourcount*final+len(era5.time)*temp)/(hourcount+len(era5.time))
        
                hourcount += len(era5.time)
                
                print(hourcount,end='\r')
        
            if i == 0:
                final.to_netcdf(directory+"flux_bins_%d.nc"%(years[i]))
            if i % 5 == 0:
                final.to_netcdf(directory+"flux_bins_%d.nc"%(years[i]))
                
        final.to_netcdf(directory+"flux_bins_all.nc")
    

    He did:

    • focus on monthly files rather than the yearly ones
    • expand the ds by usithr (instead of looping over them)
    • chunking the data
    • using the histogram function from xhistogram.xarray

    Maybe it could be even improved more e.g. chunking or implementing the ds.groupby_bins and arguments for open_mfdatasetas suggested by @RichSignell.