Looking to use output from argmax to select level in another array


I am attempting to use the ERA-Interim vegetation cover fraction to determine the dominant vegetation cover. Essentially, there are 4 variables: low vegetation cover fraction, high vegetation cover fraction, low vegetation type and high vegetation type.

ERAI_high_veg_frac = ERAI_ds[['CVH_GDS4_SFC_S123']]
ERAI_high_veg_frac = ERAI_high_veg_frac.rename({'CVH_GDS4_SFC_S123':'vegetation_cover'})
ERAI_high_veg_frac = ERAI_high_veg_frac.sel(initial_time0_hours=slice('2010-06-01','2010-06-01'),drop=True)

ERAI_low_veg_frac = ERAI_ds[['CVL_GDS4_SFC_S123']]
ERAI_low_veg_frac = ERAI_low_veg_frac.rename({'CVL_GDS4_SFC_S123':'vegetation_cover'})
ERAI_low_veg_frac = ERAI_low_veg_frac.sel(initial_time0_hours=slice('2010-06-01','2010-06-01'),drop=True)

ERAI_high_veg_type = ERAI_ds[['TVH_GDS4_SFC_S123']]
ERAI_high_veg_type = ERAI_high_veg_type.rename({'TVH_GDS4_SFC_S123':'vegetation_type'})
ERAI_high_veg_type = ERAI_high_veg_type.sel(initial_time0_hours=slice('2010-06-01','2010-06-01'),drop=True)

ERAI_low_veg_type = ERAI_ds[['TVL_GDS4_SFC_S123']]
ERAI_low_veg_type = ERAI_low_veg_type.rename({'TVL_GDS4_SFC_S123':'vegetation_type'})
ERAI_low_veg_type = ERAI_low_veg_type.sel(initial_time0_hours=slice('2010-06-01','2010-06-01'),drop=True)

I am using xr.concat() and argmax to determine which vegetation type is dominant, based on which of the low and high vegetation cover fractions is larger:

ERAI_vegfrac = xr.concat([ERAI_high_veg_frac,ERAI_low_veg_frac],pd.Index(['High Veg Frac','Low Veg Frac'],name='veg_cover'),fill_value=np.nan)
ERAI_domvegfrac = ERAI_vegfrac.max('veg_cover',skipna=True)
ERAI_domvegidx = ERAI_vegfrac.argmax('veg_cover',skipna=True)

ERAI_domvegidx = ERAI_domvegidx['vegetation_cover']

This produces an xarray.DataArray called 'vegetation_cover' with the following structure:

<xarray.DataArray 'vegetation_cover' (initial_time0_hours: 1, g4_lat_1: 256, g4_lon_2: 512)>
dask.array<nanarg_agg-aggregate, shape=(1, 256, 512), dtype=int64, chunksize=(1, 256, 512), chunktype=numpy.ndarray>
Coordinates:
  * initial_time0_hours  (initial_time0_hours) datetime64[ns] 2010-06-01
  * g4_lon_2             (g4_lon_2) float32 0.0 0.7031 1.406 ... 358.6 359.3
  * g4_lat_1             (g4_lat_1) float32 89.46 88.77 88.07 ... -88.77 -89.46

I have also used xr.concat() to combine the vegetation types into an xarray.DataArray ('vegetation_type') in a similar fashion:

ERAI_vegtype = xr.concat([ERAI_high_veg_type,ERAI_low_veg_type],pd.Index(['High Veg Type','Low Veg Type'],name='veg_type'),fill_value=np.nan)
ERAI_vegtype = ERAI_vegtype['vegetation_type']

Which has the following structure:

<xarray.DataArray 'vegetation_type' (veg_type: 2, initial_time0_hours: 1,g4_lat_1: 256, g4_lon_2: 512)>
dask.array<concatenate, shape=(2, 1, 256, 512), dtype=float32, chunksize=(1, 1, 256, 512), chunktype=numpy.ndarray>
Coordinates:
  * initial_time0_hours  (initial_time0_hours) datetime64[ns] 2010-06-01
  * g4_lon_2             (g4_lon_2) float32 0.0 0.7031 1.406 ... 358.6 359.3
  * g4_lat_1             (g4_lat_1) float32 89.46 88.77 88.07 ... -88.77 -89.46
  * veg_type             (veg_type) object 'High Veg Type' 'Low Veg Type'

My aim is to use the integer values in ERAI_domvegidx (0 or 1) to select the level (veg_type) of ERAI_vegtype, such that if ERAI_domvegidx = 0, we select veg_type='High Veg Type', and if ERAI_domvegidx = 1, we select veg_type='Low Veg Type':

ERAI_domveg = ERAI_vegtype.isel(veg_type=ERAI_domvegidx)

However this returns a "TypeError: unexpected indexer type for VectorizedIndexer: dask.array<nanarg_agg-aggregate, shape=(1, 256, 512), dtype=int64, chunksize=(1, 256, 512), chunktype=numpy.ndarray>"

Is it possible to use the integer index "ERAI_domvegidx" to select the level of the "veg_type" coordinate in "ERAI_vegtype"?


Solution

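  • argmax returns integer positions along veg_cover (0 or 1), not the veg_type labels, and judging by the error message the result is still a lazy dask array, which .isel() does not accept as a vectorized indexer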

    You could use .where() to create a new DataArray with the same dimensions as ERAI_vegtype

    veg_type_names = ['High Veg Type', 'Low Veg Type']
    ERAI_domveg = ERAI_vegtype.where(
        ERAI_domvegidx == 0,
        other=ERAI_vegtype.sel(veg_type=veg_type_names[1])
    ).where(
        ERAI_domvegidx == 1,
        other=ERAI_vegtype.sel(veg_type=veg_type_names[0])
    )
    

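    Because both .where() calls pass other, every cell is filled from one of the two layers: the low type where ERAI_domvegidx is 1 and the high type elsewhere. NaNs appear only where the selected vegetation type is itself NaN.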
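
The same selection can be sketched on toy data with two alternatives to the chained .where(): a single xr.where() call, and .isel() with the index computed first. This is a minimal sketch, not a tested fix for the ERA-Interim dataset: the 2x2 arrays and their values are made up, the names lat, lon, dom_where and dom_isel are illustrative, and it assumes the TypeError comes from the index being dask-backed, as the message suggests.

```python
import xarray as xr

# Made-up stand-ins for the cover fractions and vegetation types.
high_frac = xr.DataArray([[0.7, 0.1], [0.4, 0.0]], dims=("lat", "lon"))
low_frac = xr.DataArray([[0.2, 0.6], [0.5, 0.3]], dims=("lat", "lon"))
high_type = xr.DataArray([[3.0, 5.0], [6.0, 18.0]], dims=("lat", "lon"))
low_type = xr.DataArray([[1.0, 2.0], [7.0, 9.0]], dims=("lat", "lon"))

# Stack along a new dimension: position 0 = high vegetation, 1 = low vegetation.
frac = xr.concat([high_frac, low_frac], dim="veg_cover")
vegtype = xr.concat([high_type, low_type], dim="veg_type")

# Integer position (0 or 1) of the larger cover fraction in each grid cell.
idx = frac.argmax("veg_cover")

# Option 1: pick between the two layers cell by cell with xr.where().
dom_where = xr.where(idx == 0, vegtype.isel(veg_type=0), vegtype.isel(veg_type=1))

# Option 2: vectorized .isel(); .compute() turns a lazy (dask-backed) index
# into an in-memory one and is harmless on NumPy-backed data like this.
dom_isel = vegtype.isel(veg_type=idx.compute())

print(dom_where.values)
print(dom_isel.values)
```

Both options return an array without the veg_type dimension, unlike the chained .where() above, which keeps it. With only two layers xr.where() is the simpler choice; the .isel() form extends to any number of layers.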