
plot where xth variable is maximum with xarray


I'm attempting to create a plot which shows regions where one regression predictor is more important than all others, particularly where its relative importance (obtained with pingouin's linear_regression(relimp=True)) is the maximum value for the set of predictors. This data is stored like so:

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import xarray as xr

data = xr.open_dataset('.../file.nc')
# opens file.nc as an xarray Dataset; file contains v variables on d dimensions
# example variable: relimppct. this variable's dimensions are lat, lon, modelName, and varname.
# lat and lon are simply latitude and longitude; modelName represents the
# climate model the regression data was created with and varname represents the regressor used.
# so, to create a plot with panels for each relimppct regressor averaged across all models:
# 'projection' belongs to the axes; 'transform' describes the data and goes in the plot call
fig, axs = plt.subplots(1, len(data.varname),
                        subplot_kw={'projection': ccrs.PlateCarree()})
for i in range(len(data.varname)):
    data.relimppct.mean(dim='modelName').isel(varname=i).plot(ax=axs[i], transform=ccrs.PlateCarree())
plt.show()

What I'd like to do is create a single-panel plot with filled contours showing where each regression coefficient is most dominant, with a single categorical value for each. As an example, see figure 4 from Naud et al. 2023:

[Figure: Naud et al. 2023, fig. 4]

Their plot shows where each of their meteorological variables is most correlated with the presence of clouds. My immediate thoughts are a) that I will need to use xarray's where function, as something along the lines of data.relimppct.where(data.relimppct.isel(varname=i)==data.relimppct.max(dim='varname')), and b) that I will have to step outside the built-in plotting framework. Please share any suggestions or questions.


Solution

  • My idea is to rename the variables to integers and then use to_array() to stack them along a new dimension of a single DataArray. idxmax() along that dimension then returns, for each grid point, the integer label of the variable that holds the maximum.

    Here is an example dataset whose variables each hold the yearly maximum sea surface temperature for one of six years. The plot shows which of those six years has the largest maximum at each grid point.

    import numpy as np
    import xarray as xr

    data = xr.tutorial.load_dataset("ersstv5")
    yearly_max = data.sst.sel(time=slice("2000-01-01", "2010-12-31")).groupby("time.year").max("time")

    # build a dataset with one variable per year; the maximum is taken across these variables
    # (drop_vars replaces the deprecated drop)
    years = [2010, 2008, 2006, 2004, 2002, 2000]
    ds = xr.Dataset({f"Y{year}": yearly_max.sel(year=year).drop_vars("year") for year in years})

    # map the variable names to integers so that the result can be plotted as categories
    ds_map = ds.rename({"Y2010": 1, "Y2008": 2, "Y2006": 3, "Y2004": 4, "Y2002": 5, "Y2000": 6})

    # stack the variables along "which_year", then take the label of the maximum at each grid point;
    # the levels 0.5, 1.5, ..., 6.5 are bin edges that give one color per integer label
    cs = ds_map.to_array("which_year").idxmax("which_year").plot(
        cmap="tab10", levels=np.arange(0.5, 7.5),
        cbar_kwargs={"label": "Year of Maximum SST", "orientation": "horizontal"})
    cs.colorbar.set_ticks(np.arange(1, 7))
    cs.colorbar.set_ticklabels([str(year) for year in years])
    

    The resulting figure looks like this:

    [Figure: year with max temperature]
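The same idea can be applied directly to a DataArray that already has the predictors along one dimension, as in the question, so the to_array() step is not needed. The following is only a minimal sketch under stated assumptions: the data are synthetic random numbers standing in for relimppct, the predictor names ("sst", "wind", "humidity") and the helpers dominant_predictor and plot_dominant are hypothetical, and the dimension names are taken from the question.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the question's relimppct variable (assumption: random
# values; the predictor names below are made up for illustration).
varnames = ["sst", "wind", "humidity"]
rng = np.random.default_rng(0)
relimppct = xr.DataArray(
    rng.random((len(varnames), 2, 4, 6)),
    dims=("varname", "modelName", "lat", "lon"),
    coords={
        "varname": varnames,
        "modelName": ["modelA", "modelB"],
        "lat": np.linspace(-30, 30, 4),
        "lon": np.linspace(0, 100, 6),
    },
    name="relimppct",
)


def dominant_predictor(da, dim="varname"):
    """Return, per grid point, the integer position along `dim` of the maximum."""
    # Relabel the predictors 0..n-1 so that idxmax returns plottable numbers
    # instead of strings.
    codes = np.arange(da.sizes[dim])
    return da.assign_coords({dim: codes}).idxmax(dim)


def plot_dominant(dominant, labels):
    """Plot the integer codes as discrete categories, one color per predictor."""
    import matplotlib.pyplot as plt  # imported here so the rest runs without it

    n = len(labels)
    mesh = dominant.plot(
        cmap="tab10",
        levels=np.arange(-0.5, n + 0.5),  # bin edges centered on 0, 1, ..., n-1
        cbar_kwargs={"label": "Dominant predictor"},
    )
    mesh.colorbar.set_ticks(np.arange(n))
    mesh.colorbar.set_ticklabels(labels)
    return mesh


# Average across models first, then find the dominant predictor per grid point.
mean_importance = relimppct.mean(dim="modelName")
dominant = dominant_predictor(mean_importance)
```

Calling plot_dominant(dominant, varnames) followed by plt.show() should then draw a single panel with one color per predictor. With the default fill_value, idxmax leaves grid points where every predictor is missing as NaN, so masked regions stay unclassified.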